diff --git a/ollama/__init__.py b/ollama/__init__.py index c452f71..23d736a 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -1,10 +1,17 @@ from ollama._client import Client, AsyncClient from ollama._types import ( + Options, + Message, + Tool, GenerateResponse, ChatResponse, + EmbedResponse, + EmbeddingsResponse, + StatusResponse, ProgressResponse, - Message, - Options, + ListResponse, + ShowResponse, + ProcessResponse, RequestError, ResponseError, ) @@ -12,25 +19,20 @@ __all__ = [ 'Client', 'AsyncClient', + 'Options', + 'Message', + 'Tool', 'GenerateResponse', 'ChatResponse', + 'EmbedResponse', + 'EmbeddingsResponse', + 'StatusResponse', 'ProgressResponse', - 'Message', - 'Options', + 'ListResponse', + 'ShowResponse', + 'ProcessResponse', 'RequestError', 'ResponseError', - 'generate', - 'chat', - 'embed', - 'embeddings', - 'pull', - 'push', - 'create', - 'delete', - 'list', - 'copy', - 'show', - 'ps', ] _client = Client() diff --git a/ollama/_client.py b/ollama/_client.py index e3d9fed..c1f5f95 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -1,18 +1,24 @@ -import ipaddress import os import io import json -import httpx -import binascii import platform +import ipaddress import urllib.parse from os import PathLike from pathlib import Path -from copy import deepcopy from hashlib import sha256 -from base64 import b64encode, b64decode -from typing import Any, AnyStr, Union, Optional, Sequence, Mapping, Literal, overload +from typing import ( + Any, + Literal, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + overload, +) import sys @@ -28,7 +34,38 @@ except metadata.PackageNotFoundError: __version__ = '0.0.0' -from ollama._types import Message, Options, RequestError, ResponseError, Tool +import httpx + +from ollama._types import ( + ChatRequest, + ChatResponse, + CreateRequest, + CopyRequest, + DeleteRequest, + EmbedRequest, + EmbedResponse, + EmbeddingsRequest, + EmbeddingsResponse, + GenerateRequest, + GenerateResponse, + Image, + ListResponse, + Message, + Options, + ProcessResponse, + ProgressResponse, + PullRequest, + PushRequest, + RequestError, + ResponseError, + ShowRequest, + ShowResponse, + StatusResponse, + Tool, +) + + +T = TypeVar('T') class BaseClient: @@ -38,6 +75,7 @@ def __init__( host: Optional[str] = None, follow_redirects: bool = True, timeout: Any = None, + headers: Optional[Mapping[str, str]] = None, **kwargs, ) -> None: """ @@ -48,16 +86,15 @@ def __init__( `kwargs` are passed to the httpx client. 
""" - headers = kwargs.pop('headers', {}) - headers['Content-Type'] = 'application/json' - headers['Accept'] = 'application/json' - headers['User-Agent'] = f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}' - self._client = client( base_url=_parse_host(host or os.getenv('OLLAMA_HOST')), follow_redirects=follow_redirects, timeout=timeout, - headers=headers, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}', + }.update(headers or {}), **kwargs, ) @@ -66,37 +103,67 @@ class Client(BaseClient): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.Client, host, **kwargs) - def _request(self, method: str, url: str, **kwargs) -> httpx.Response: - response = self._client.request(method, url, **kwargs) - + def _request_raw(self, *args, **kwargs): + r = self._client.request(*args, **kwargs) try: - response.raise_for_status() + r.raise_for_status() except httpx.HTTPStatusError as e: raise ResponseError(e.response.text, e.response.status_code) from None + return r - return response + @overload + def _request( + self, + cls: Type[T], + *args, + stream: Literal[False] = False, + **kwargs, + ) -> T: ... - def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: - with self._client.stream(method, url, **kwargs) as r: - try: - r.raise_for_status() - except httpx.HTTPStatusError as e: - e.response.read() - raise ResponseError(e.response.text, e.response.status_code) from None + @overload + def _request( + self, + cls: Type[T], + *args, + stream: Literal[True] = True, + **kwargs, + ) -> Iterator[T]: ... - for line in r.iter_lines(): - partial = json.loads(line) - if e := partial.get('error'): - raise ResponseError(e) - yield partial + @overload + def _request( + self, + cls: Type[T], + *args, + stream: bool = False, + **kwargs, + ) -> Union[T, Iterator[T]]: ... - def _request_stream( + def _request( self, + cls: Type[T], *args, stream: bool = False, **kwargs, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json() + ) -> Union[T, Iterator[T]]: + if stream: + + def inner(): + with self._client.stream(*args, **kwargs) as r: + try: + r.raise_for_status() + except httpx.HTTPStatusError as e: + e.response.read() + raise ResponseError(e.response.text, e.response.status_code) from None + + for line in r.iter_lines(): + part = json.loads(line) + if err := part.get('error'): + raise ResponseError(err) + yield cls(**part) + + return inner() + + return cls(**self._request_raw(*args, **kwargs).json()) @overload def generate( @@ -104,16 +171,17 @@ def generate( model: str = '', prompt: str = '', suffix: str = '', + *, system: str = '', template: str = '', context: Optional[Sequence[int]] = None, stream: Literal[False] = False, raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: ... + ) -> GenerateResponse: ... 
@overload def generate( @@ -121,32 +189,34 @@ def generate( model: str = '', prompt: str = '', suffix: str = '', + *, system: str = '', template: str = '', context: Optional[Sequence[int]] = None, stream: Literal[True] = True, raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Iterator[Mapping[str, Any]]: ... + ) -> Iterator[GenerateResponse]: ... def generate( self, model: str = '', - prompt: str = '', - suffix: str = '', - system: str = '', - template: str = '', + prompt: Optional[str] = None, + suffix: Optional[str] = None, + *, + system: Optional[str] = None, + template: Optional[str] = None, context: Optional[Sequence[int]] = None, stream: bool = False, - raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + raw: Optional[bool] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + ) -> Union[GenerateResponse, Iterator[GenerateResponse]]: """ Create a response using the requested model. @@ -157,26 +227,24 @@ def generate( Returns `GenerateResponse` if `stream` is `False`, otherwise returns a `GenerateResponse` generator. """ - if not model: - raise RequestError('must provide a model') - - return self._request_stream( + return self._request( + GenerateResponse, 'POST', '/api/generate', - json={ - 'model': model, - 'prompt': prompt, - 'suffix': suffix, - 'system': system, - 'template': template, - 'context': context or [], - 'stream': stream, - 'raw': raw, - 'images': [_encode_image(image) for image in images or []], - 'format': format, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=GenerateRequest( + model=model, + prompt=prompt, + suffix=suffix, + system=system, + template=template, + context=context, + stream=stream, + raw=raw, + format=format, + images=[Image(value=image) for image in images] if images else None, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), stream=stream, ) @@ -184,36 +252,39 @@ def generate( def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: Literal[False] = False, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: ... + ) -> ChatResponse: ... 
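Since `_request` is now parameterized on a response class, `generate` returns `GenerateResponse` models instead of plain dicts, in both streaming and non-streaming form. A rough usage sketch, assuming a running local server and a placeholder model name:

```python
from ollama import Client

client = Client()

# Non-streaming: one fully populated GenerateResponse model.
resp = client.generate(model='llama3.1', prompt='Why is the sky blue?')
print(resp.response)      # attribute access on the pydantic model
print(resp['response'])   # subscript access still works via SubscriptableBaseModel

# Streaming: an iterator of GenerateResponse fragments.
for part in client.generate(model='llama3.1', prompt='Why is the sky blue?', stream=True):
  print(part.response, end='', flush=True)
```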
@overload def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: Literal[True] = True, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Iterator[Mapping[str, Any]]: ... + ) -> Iterator[ChatResponse]: ... def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: bool = False, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + ) -> Union[ChatResponse, Iterator[ChatResponse]]: """ Create a chat response using the requested model. @@ -224,109 +295,104 @@ def chat( Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator. """ - if not model: - raise RequestError('must provide a model') - - messages = deepcopy(messages) - - for message in messages or []: - if images := message.get('images'): - message['images'] = [_encode_image(image) for image in images] - - return self._request_stream( + return self._request( + ChatResponse, 'POST', '/api/chat', - json={ - 'model': model, - 'messages': messages, - 'tools': tools or [], - 'stream': stream, - 'format': format, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=ChatRequest( + model=model, + messages=[message for message in _copy_messages(messages)], + tools=[tool for tool in _copy_tools(tools)], + stream=stream, + format=format, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), stream=stream, ) def embed( self, model: str = '', - input: Union[str, Sequence[AnyStr]] = '', - truncate: bool = True, - options: Optional[Options] = None, + input: Union[str, Sequence[str]] = '', + truncate: Optional[bool] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: - if not model: - raise RequestError('must provide a model') - + ) -> EmbedResponse: return self._request( + EmbedResponse, 'POST', '/api/embed', - json={ - 'model': model, - 'input': input, - 'truncate': truncate, - 'options': options or {}, - 'keep_alive': keep_alive, - }, - ).json() + json=EmbedRequest( + model=model, + input=input, + truncate=truncate, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), + ) def embeddings( self, model: str = '', - prompt: str = '', - options: Optional[Options] = None, + prompt: Optional[str] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Sequence[float]]: + ) -> EmbeddingsResponse: """ Deprecated in favor of `embed`. 
""" return self._request( + EmbeddingsResponse, 'POST', '/api/embeddings', - json={ - 'model': model, - 'prompt': prompt, - 'options': options or {}, - 'keep_alive': keep_alive, - }, - ).json() + json=EmbeddingsRequest( + model=model, + prompt=prompt, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), + ) @overload def pull( self, model: str, + *, insecure: bool = False, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload def pull( self, model: str, + *, insecure: bool = False, stream: Literal[True] = True, - ) -> Iterator[Mapping[str, Any]]: ... + ) -> Iterator[ProgressResponse]: ... def pull( self, model: str, + *, insecure: bool = False, stream: bool = False, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, Iterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ - return self._request_stream( + return self._request( + ProgressResponse, 'POST', '/api/pull', - json={ - 'name': model, - 'insecure': insecure, - 'stream': stream, - }, + json=PullRequest( + model=model, + insecure=insecure, + stream=stream, + ).model_dump(exclude_none=True), stream=stream, ) @@ -334,37 +400,41 @@ def pull( def push( self, model: str, + *, insecure: bool = False, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload def push( self, model: str, + *, insecure: bool = False, stream: Literal[True] = True, - ) -> Iterator[Mapping[str, Any]]: ... + ) -> Iterator[ProgressResponse]: ... def push( self, model: str, + *, insecure: bool = False, stream: bool = False, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, Iterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ - return self._request_stream( + return self._request( + ProgressResponse, 'POST', '/api/push', - json={ - 'name': model, - 'insecure': insecure, - 'stream': stream, - }, + json=PushRequest( + model=model, + insecure=insecure, + stream=stream, + ).model_dump(exclude_none=True), stream=stream, ) @@ -374,9 +444,10 @@ def create( model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload def create( @@ -384,18 +455,20 @@ def create( model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: Literal[True] = True, - ) -> Iterator[Mapping[str, Any]]: ... + ) -> Iterator[ProgressResponse]: ... def create( self, model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: bool = False, - ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, Iterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. 
@@ -408,15 +481,16 @@ def create( else: raise RequestError('must provide either path or modelfile') - return self._request_stream( + return self._request( + ProgressResponse, 'POST', '/api/create', - json={ - 'name': model, - 'modelfile': modelfile, - 'stream': stream, - 'quantize': quantize, - }, + json=CreateRequest( + model=model, + modelfile=modelfile, + stream=stream, + quantize=quantize, + ).model_dump(exclude_none=True), stream=stream, ) @@ -450,76 +524,131 @@ def _create_blob(self, path: Union[str, Path]) -> str: digest = f'sha256:{sha256sum.hexdigest()}' try: - self._request('HEAD', f'/api/blobs/{digest}') + self._request_raw('HEAD', f'/api/blobs/{digest}') except ResponseError as e: if e.status_code != 404: raise with open(path, 'rb') as r: - self._request('POST', f'/api/blobs/{digest}', content=r) + self._request_raw('POST', f'/api/blobs/{digest}', content=r) return digest - def delete(self, model: str) -> Mapping[str, Any]: - response = self._request('DELETE', '/api/delete', json={'name': model}) - return {'status': 'success' if response.status_code == 200 else 'error'} + def list(self) -> ListResponse: + return self._request( + ListResponse, + 'GET', + '/api/tags', + ) - def list(self) -> Mapping[str, Any]: - return self._request('GET', '/api/tags').json() + def delete(self, model: str) -> StatusResponse: + r = self._request_raw( + 'DELETE', + '/api/delete', + json=DeleteRequest( + model=model, + ).model_dump(exclude_none=True), + ) + return StatusResponse( + status='success' if r.status_code == 200 else 'error', + ) - def copy(self, source: str, destination: str) -> Mapping[str, Any]: - response = self._request('POST', '/api/copy', json={'source': source, 'destination': destination}) - return {'status': 'success' if response.status_code == 200 else 'error'} + def copy(self, source: str, destination: str) -> StatusResponse: + r = self._request_raw( + 'POST', + '/api/copy', + json=CopyRequest( + source=source, + destination=destination, + ).model_dump(exclude_none=True), + ) + return StatusResponse( + status='success' if r.status_code == 200 else 'error', + ) - def show(self, model: str) -> Mapping[str, Any]: - return self._request('POST', '/api/show', json={'name': model}).json() + def show(self, model: str) -> ShowResponse: + return self._request( + ShowResponse, + 'POST', + '/api/show', + json=ShowRequest( + model=model, + ).model_dump(exclude_none=True), + ) - def ps(self) -> Mapping[str, Any]: - return self._request('GET', '/api/ps').json() + def ps(self) -> ProcessResponse: + return self._request( + ProcessResponse, + 'GET', + '/api/ps', + ) class AsyncClient(BaseClient): def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.AsyncClient, host, **kwargs) - async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: - response = await self._client.request(method, url, **kwargs) - + async def _request_raw(self, *args, **kwargs): + r = await self._client.request(*args, **kwargs) try: - response.raise_for_status() + r.raise_for_status() except httpx.HTTPStatusError as e: raise ResponseError(e.response.text, e.response.status_code) from None + return r - return response - - async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: - async def inner(): - async with self._client.stream(method, url, **kwargs) as r: - try: - r.raise_for_status() - except httpx.HTTPStatusError as e: - await e.response.aread() - raise ResponseError(e.response.text, e.response.status_code) from None + @overload + 
async def _request( + self, + cls: Type[T], + *args, + stream: Literal[False] = False, + **kwargs, + ) -> T: ... - async for line in r.aiter_lines(): - partial = json.loads(line) - if e := partial.get('error'): - raise ResponseError(e) - yield partial + @overload + async def _request( + self, + cls: Type[T], + *args, + stream: Literal[True] = True, + **kwargs, + ) -> AsyncIterator[T]: ... - return inner() + @overload + async def _request( + self, + cls: Type[T], + *args, + stream: bool = False, + **kwargs, + ) -> Union[T, AsyncIterator[T]]: ... - async def _request_stream( + async def _request( self, + cls: Type[T], *args, stream: bool = False, **kwargs, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[T, AsyncIterator[T]]: if stream: - return await self._stream(*args, **kwargs) - response = await self._request(*args, **kwargs) - return response.json() + async def inner(): + async with self._client.stream(*args, **kwargs) as r: + try: + r.raise_for_status() + except httpx.HTTPStatusError as e: + await e.response.aread() + raise ResponseError(e.response.text, e.response.status_code) from None + + async for line in r.aiter_lines(): + part = json.loads(line) + if err := part.get('error'): + raise ResponseError(err) + yield cls(**part) + + return inner() + + return cls(**(await self._request_raw(*args, **kwargs)).json()) @overload async def generate( @@ -527,16 +656,17 @@ async def generate( model: str = '', prompt: str = '', suffix: str = '', + *, system: str = '', template: str = '', context: Optional[Sequence[int]] = None, stream: Literal[False] = False, raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: ... + ) -> GenerateResponse: ... @overload async def generate( @@ -544,32 +674,34 @@ async def generate( model: str = '', prompt: str = '', suffix: str = '', + *, system: str = '', template: str = '', context: Optional[Sequence[int]] = None, stream: Literal[True] = True, raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> AsyncIterator[Mapping[str, Any]]: ... + ) -> AsyncIterator[GenerateResponse]: ... 
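The async client mirrors the sync dispatch: with `stream=True`, awaiting the call yields an async generator of typed models. A sketch, assuming a running server and a placeholder model name:

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
  client = AsyncClient()
  # generate() returns an AsyncIterator[GenerateResponse] when streaming.
  async for part in await client.generate(model='llama3.1', prompt='Hello!', stream=True):
    print(part.response, end='', flush=True)


asyncio.run(main())
```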
async def generate( self, model: str = '', - prompt: str = '', - suffix: str = '', - system: str = '', - template: str = '', + prompt: Optional[str] = None, + suffix: Optional[str] = None, + *, + system: Optional[str] = None, + template: Optional[str] = None, context: Optional[Sequence[int]] = None, stream: bool = False, - raw: bool = False, - format: Literal['', 'json'] = '', - images: Optional[Sequence[AnyStr]] = None, - options: Optional[Options] = None, + raw: Optional[bool] = None, + format: Optional[Literal['', 'json']] = None, + images: Optional[Sequence[Union[str, bytes]]] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]: """ Create a response using the requested model. @@ -579,26 +711,24 @@ async def generate( Returns `GenerateResponse` if `stream` is `False`, otherwise returns an asynchronous `GenerateResponse` generator. """ - if not model: - raise RequestError('must provide a model') - - return await self._request_stream( + return await self._request( + GenerateResponse, 'POST', '/api/generate', - json={ - 'model': model, - 'prompt': prompt, - 'suffix': suffix, - 'system': system, - 'template': template, - 'context': context or [], - 'stream': stream, - 'raw': raw, - 'images': [_encode_image(image) for image in images or []], - 'format': format, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=GenerateRequest( + model=model, + prompt=prompt, + suffix=suffix, + system=system, + template=template, + context=context, + stream=stream, + raw=raw, + format=format, + images=[Image(value=image) for image in images] if images else None, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), stream=stream, ) @@ -606,36 +736,39 @@ async def generate( async def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: Literal[False] = False, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: ... + ) -> ChatResponse: ... @overload async def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: Literal[True] = True, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> AsyncIterator[Mapping[str, Any]]: ... + ) -> AsyncIterator[ChatResponse]: ... 
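The `chat` implementation below routes input through `_copy_messages` (defined later in this diff), which validates plain mappings into `Message` models and wraps raw image payloads in `Image` so they serialize to base64. A rough sketch of that effect; the byte string is a stand-in for real image data, and `Image` lives in `ollama._types`:

```python
from ollama import Message
from ollama._types import Image

msg = Message.model_validate({
  'role': 'user',
  'content': 'Describe this image.',
  'images': [Image(value=b'raw image bytes')],  # placeholder bytes
})
# Image's model_serializer emits base64, so the dumped payload matches
# the wire format the server expects.
print(msg.model_dump(exclude_none=True))
```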
async def chat( self, model: str = '', - messages: Optional[Sequence[Message]] = None, - tools: Optional[Sequence[Tool]] = None, + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, + *, + tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None, stream: bool = False, - format: Literal['', 'json'] = '', - options: Optional[Options] = None, + format: Optional[Literal['', 'json']] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: """ Create a chat response using the requested model. @@ -645,113 +778,105 @@ async def chat( Returns `ChatResponse` if `stream` is `False`, otherwise returns an asynchronous `ChatResponse` generator. """ - if not model: - raise RequestError('must provide a model') - - messages = deepcopy(messages) - - for message in messages or []: - if images := message.get('images'): - message['images'] = [_encode_image(image) for image in images] - return await self._request_stream( + return await self._request( + ChatResponse, 'POST', '/api/chat', - json={ - 'model': model, - 'messages': messages, - 'tools': tools or [], - 'stream': stream, - 'format': format, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=ChatRequest( + model=model, + messages=[message for message in _copy_messages(messages)], + tools=[tool for tool in _copy_tools(tools)], + stream=stream, + format=format, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), stream=stream, ) async def embed( self, model: str = '', - input: Union[str, Sequence[AnyStr]] = '', - truncate: bool = True, - options: Optional[Options] = None, + input: Union[str, Sequence[str]] = '', + truncate: Optional[bool] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Any]: - if not model: - raise RequestError('must provide a model') - - response = await self._request( + ) -> EmbedResponse: + return await self._request( + EmbedResponse, 'POST', '/api/embed', - json={ - 'model': model, - 'input': input, - 'truncate': truncate, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=EmbedRequest( + model=model, + input=input, + truncate=truncate, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), ) - return response.json() - async def embeddings( self, model: str = '', - prompt: str = '', - options: Optional[Options] = None, + prompt: Optional[str] = None, + options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, - ) -> Mapping[str, Sequence[float]]: + ) -> EmbeddingsResponse: """ Deprecated in favor of `embed`. """ - response = await self._request( + return await self._request( + EmbeddingsResponse, 'POST', '/api/embeddings', - json={ - 'model': model, - 'prompt': prompt, - 'options': options or {}, - 'keep_alive': keep_alive, - }, + json=EmbeddingsRequest( + model=model, + prompt=prompt, + options=options, + keep_alive=keep_alive, + ).model_dump(exclude_none=True), ) - return response.json() - @overload async def pull( self, model: str, + *, insecure: bool = False, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload async def pull( self, model: str, + *, insecure: bool = False, stream: Literal[True] = True, - ) -> AsyncIterator[Mapping[str, Any]]: ... 
+ ) -> AsyncIterator[ProgressResponse]: ... async def pull( self, model: str, + *, insecure: bool = False, stream: bool = False, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ - return await self._request_stream( + return await self._request( + ProgressResponse, 'POST', '/api/pull', - json={ - 'name': model, - 'insecure': insecure, - 'stream': stream, - }, + json=PullRequest( + model=model, + insecure=insecure, + stream=stream, + ).model_dump(exclude_none=True), stream=stream, ) @@ -759,37 +884,41 @@ async def pull( async def push( self, model: str, + *, insecure: bool = False, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload async def push( self, model: str, + *, insecure: bool = False, stream: Literal[True] = True, - ) -> AsyncIterator[Mapping[str, Any]]: ... + ) -> AsyncIterator[ProgressResponse]: ... async def push( self, model: str, + *, insecure: bool = False, stream: bool = False, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator. """ - return await self._request_stream( + return await self._request( + ProgressResponse, 'POST', '/api/push', - json={ - 'name': model, - 'insecure': insecure, - 'stream': stream, - }, + json=PushRequest( + model=model, + insecure=insecure, + stream=stream, + ).model_dump(exclude_none=True), stream=stream, ) @@ -799,9 +928,10 @@ async def create( model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: Literal[False] = False, - ) -> Mapping[str, Any]: ... + ) -> ProgressResponse: ... @overload async def create( @@ -809,18 +939,20 @@ async def create( model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: Literal[True] = True, - ) -> AsyncIterator[Mapping[str, Any]]: ... + ) -> AsyncIterator[ProgressResponse]: ... async def create( self, model: str, path: Optional[Union[str, PathLike]] = None, modelfile: Optional[str] = None, + *, quantize: Optional[str] = None, stream: bool = False, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + ) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]: """ Raises `ResponseError` if the request could not be fulfilled. 
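`pull` and `push` stream typed `ProgressResponse` updates; `completed` and `total` are optional because only some status lines carry them. A sketch with a placeholder model name:

```python
from ollama import Client

client = Client()

for progress in client.pull('llama3.1', stream=True):
  if progress.completed and progress.total:
    print(f'{progress.status}: {progress.completed / progress.total:.0%}')
  else:
    print(progress.status)
```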
@@ -833,15 +965,16 @@ async def create( else: raise RequestError('must provide either path or modelfile') - return await self._request_stream( + return await self._request( + ProgressResponse, 'POST', '/api/create', - json={ - 'name': model, - 'modelfile': modelfile, - 'stream': stream, - 'quantize': quantize, - }, + json=CreateRequest( + model=model, + modelfile=modelfile, + stream=stream, + quantize=quantize, + ).model_dump(exclude_none=True), stream=stream, ) @@ -875,7 +1008,7 @@ async def _create_blob(self, path: Union[str, Path]) -> str: digest = f'sha256:{sha256sum.hexdigest()}' try: - await self._request('HEAD', f'/api/blobs/{digest}') + await self._request_raw('HEAD', f'/api/blobs/{digest}') except ResponseError as e: if e.status_code != 404: raise @@ -888,60 +1021,70 @@ async def upload_bytes(): break yield chunk - await self._request('POST', f'/api/blobs/{digest}', content=upload_bytes()) + await self._request_raw('POST', f'/api/blobs/{digest}', content=upload_bytes()) return digest - async def delete(self, model: str) -> Mapping[str, Any]: - response = await self._request('DELETE', '/api/delete', json={'name': model}) - return {'status': 'success' if response.status_code == 200 else 'error'} - - async def list(self) -> Mapping[str, Any]: - response = await self._request('GET', '/api/tags') - return response.json() - - async def copy(self, source: str, destination: str) -> Mapping[str, Any]: - response = await self._request('POST', '/api/copy', json={'source': source, 'destination': destination}) - return {'status': 'success' if response.status_code == 200 else 'error'} + async def list(self) -> ListResponse: + return await self._request( + ListResponse, + 'GET', + '/api/tags', + ) - async def show(self, model: str) -> Mapping[str, Any]: - response = await self._request('POST', '/api/show', json={'name': model}) - return response.json() + async def delete(self, model: str) -> StatusResponse: + r = await self._request_raw( + 'DELETE', + '/api/delete', + json=DeleteRequest( + model=model, + ).model_dump(exclude_none=True), + ) + return StatusResponse( + status='success' if r.status_code == 200 else 'error', + ) - async def ps(self) -> Mapping[str, Any]: - response = await self._request('GET', '/api/ps') - return response.json() + async def copy(self, source: str, destination: str) -> StatusResponse: + r = await self._request_raw( + 'POST', + '/api/copy', + json=CopyRequest( + source=source, + destination=destination, + ).model_dump(exclude_none=True), + ) + return StatusResponse( + status='success' if r.status_code == 200 else 'error', + ) + async def show(self, model: str) -> ShowResponse: + return await self._request( + ShowResponse, + 'POST', + '/api/show', + json=ShowRequest( + model=model, + ).model_dump(exclude_none=True), + ) -def _encode_image(image) -> str: - """ - >>> _encode_image(b'ollama') - 'b2xsYW1h' - >>> _encode_image(io.BytesIO(b'ollama')) - 'b2xsYW1h' - >>> _encode_image('LICENSE') - 
'TUlUIExpY2Vuc2UKCkNvcHlyaWdodCAoYykgT2xsYW1hCgpQZXJtaXNzaW9uIGlzIGhlcmVieSBncmFudGVkLCBmcmVlIG9mIGNoYXJnZSwgdG8gYW55IHBlcnNvbiBvYnRhaW5pbmcgYSBjb3B5Cm9mIHRoaXMgc29mdHdhcmUgYW5kIGFzc29jaWF0ZWQgZG9jdW1lbnRhdGlvbiBmaWxlcyAodGhlICJTb2Z0d2FyZSIpLCB0byBkZWFsCmluIHRoZSBTb2Z0d2FyZSB3aXRob3V0IHJlc3RyaWN0aW9uLCBpbmNsdWRpbmcgd2l0aG91dCBsaW1pdGF0aW9uIHRoZSByaWdodHMKdG8gdXNlLCBjb3B5LCBtb2RpZnksIG1lcmdlLCBwdWJsaXNoLCBkaXN0cmlidXRlLCBzdWJsaWNlbnNlLCBhbmQvb3Igc2VsbApjb3BpZXMgb2YgdGhlIFNvZnR3YXJlLCBhbmQgdG8gcGVybWl0IHBlcnNvbnMgdG8gd2hvbSB0aGUgU29mdHdhcmUgaXMKZnVybmlzaGVkIHRvIGRvIHNvLCBzdWJqZWN0IHRvIHRoZSBmb2xsb3dpbmcgY29uZGl0aW9uczoKClRoZSBhYm92ZSBjb3B5cmlnaHQgbm90aWNlIGFuZCB0aGlzIHBlcm1pc3Npb24gbm90aWNlIHNoYWxsIGJlIGluY2x1ZGVkIGluIGFsbApjb3BpZXMgb3Igc3Vic3RhbnRpYWwgcG9ydGlvbnMgb2YgdGhlIFNvZnR3YXJlLgoKVEhFIFNPRlRXQVJFIElTIFBST1ZJREVEICJBUyBJUyIsIFdJVEhPVVQgV0FSUkFOVFkgT0YgQU5ZIEtJTkQsIEVYUFJFU1MgT1IKSU1QTElFRCwgSU5DTFVESU5HIEJVVCBOT1QgTElNSVRFRCBUTyBUSEUgV0FSUkFOVElFUyBPRiBNRVJDSEFOVEFCSUxJVFksCkZJVE5FU1MgRk9SIEEgUEFSVElDVUxBUiBQVVJQT1NFIEFORCBOT05JTkZSSU5HRU1FTlQuIElOIE5PIEVWRU5UIFNIQUxMIFRIRQpBVVRIT1JTIE9SIENPUFlSSUdIVCBIT0xERVJTIEJFIExJQUJMRSBGT1IgQU5ZIENMQUlNLCBEQU1BR0VTIE9SIE9USEVSCkxJQUJJTElUWSwgV0hFVEhFUiBJTiBBTiBBQ1RJT04gT0YgQ09OVFJBQ1QsIFRPUlQgT1IgT1RIRVJXSVNFLCBBUklTSU5HIEZST00sCk9VVCBPRiBPUiBJTiBDT05ORUNUSU9OIFdJVEggVEhFIFNPRlRXQVJFIE9SIFRIRSBVU0UgT1IgT1RIRVIgREVBTElOR1MgSU4gVEhFClNPRlRXQVJFLgo=' - >>> _encode_image(Path('LICENSE')) - 'TUlUIExpY2Vuc2UKCkNvcHlyaWdodCAoYykgT2xsYW1hCgpQZXJtaXNzaW9uIGlzIGhlcmVieSBncmFudGVkLCBmcmVlIG9mIGNoYXJnZSwgdG8gYW55IHBlcnNvbiBvYnRhaW5pbmcgYSBjb3B5Cm9mIHRoaXMgc29mdHdhcmUgYW5kIGFzc29jaWF0ZWQgZG9jdW1lbnRhdGlvbiBmaWxlcyAodGhlICJTb2Z0d2FyZSIpLCB0byBkZWFsCmluIHRoZSBTb2Z0d2FyZSB3aXRob3V0IHJlc3RyaWN0aW9uLCBpbmNsdWRpbmcgd2l0aG91dCBsaW1pdGF0aW9uIHRoZSByaWdodHMKdG8gdXNlLCBjb3B5LCBtb2RpZnksIG1lcmdlLCBwdWJsaXNoLCBkaXN0cmlidXRlLCBzdWJsaWNlbnNlLCBhbmQvb3Igc2VsbApjb3BpZXMgb2YgdGhlIFNvZnR3YXJlLCBhbmQgdG8gcGVybWl0IHBlcnNvbnMgdG8gd2hvbSB0aGUgU29mdHdhcmUgaXMKZnVybmlzaGVkIHRvIGRvIHNvLCBzdWJqZWN0IHRvIHRoZSBmb2xsb3dpbmcgY29uZGl0aW9uczoKClRoZSBhYm92ZSBjb3B5cmlnaHQgbm90aWNlIGFuZCB0aGlzIHBlcm1pc3Npb24gbm90aWNlIHNoYWxsIGJlIGluY2x1ZGVkIGluIGFsbApjb3BpZXMgb3Igc3Vic3RhbnRpYWwgcG9ydGlvbnMgb2YgdGhlIFNvZnR3YXJlLgoKVEhFIFNPRlRXQVJFIElTIFBST1ZJREVEICJBUyBJUyIsIFdJVEhPVVQgV0FSUkFOVFkgT0YgQU5ZIEtJTkQsIEVYUFJFU1MgT1IKSU1QTElFRCwgSU5DTFVESU5HIEJVVCBOT1QgTElNSVRFRCBUTyBUSEUgV0FSUkFOVElFUyBPRiBNRVJDSEFOVEFCSUxJVFksCkZJVE5FU1MgRk9SIEEgUEFSVElDVUxBUiBQVVJQT1NFIEFORCBOT05JTkZSSU5HRU1FTlQuIElOIE5PIEVWRU5UIFNIQUxMIFRIRQpBVVRIT1JTIE9SIENPUFlSSUdIVCBIT0xERVJTIEJFIExJQUJMRSBGT1IgQU5ZIENMQUlNLCBEQU1BR0VTIE9SIE9USEVSCkxJQUJJTElUWSwgV0hFVEhFUiBJTiBBTiBBQ1RJT04gT0YgQ09OVFJBQ1QsIFRPUlQgT1IgT1RIRVJXSVNFLCBBUklTSU5HIEZST00sCk9VVCBPRiBPUiBJTiBDT05ORUNUSU9OIFdJVEggVEhFIFNPRlRXQVJFIE9SIFRIRSBVU0UgT1IgT1RIRVIgREVBTElOR1MgSU4gVEhFClNPRlRXQVJFLgo=' - >>> _encode_image('YWJj') - 'YWJj' - >>> _encode_image(b'YWJj') - 'YWJj' - """ + async def ps(self) -> ProcessResponse: + return await self._request( + ProcessResponse, + 'GET', + '/api/ps', + ) - if p := _as_path(image): - return b64encode(p.read_bytes()).decode('utf-8') - try: - b64decode(image, validate=True) - return image if isinstance(image, str) else image.decode('utf-8') - except (binascii.Error, TypeError): - ... 
+def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]]]) -> Iterator[Message]: + for message in messages or []: + yield Message.model_validate( + {k: [Image(value=image) for image in v] if k == 'images' else v for k, v in dict(message).items() if v}, + ) - if b := _as_bytesio(image): - return b64encode(b.read()).decode('utf-8') - raise RequestError('image must be bytes, path-like object, or file-like object') +def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]: + for tool in tools or []: + yield Tool.model_validate(tool) def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: @@ -954,14 +1097,6 @@ def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: return None -def _as_bytesio(s: Any) -> Union[io.BytesIO, None]: - if isinstance(s, io.BytesIO): - return s - elif isinstance(s, bytes): - return io.BytesIO(s) - return None - - def _parse_host(host: Optional[str]) -> str: """ >>> _parse_host(None) @@ -1039,9 +1174,9 @@ def _parse_host(host: Optional[str]) -> str: host = split.hostname or '127.0.0.1' port = split.port or port - # Fix missing square brackets for IPv6 from urlsplit try: if isinstance(ipaddress.ip_address(host), ipaddress.IPv6Address): + # Fix missing square brackets for IPv6 from urlsplit host = f'[{host}]' except ValueError: ... diff --git a/ollama/_types.py b/ollama/_types.py index 7bdcbe4..b223d9c 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -1,43 +1,162 @@ import json -from typing import Any, TypedDict, Sequence, Literal, Mapping +from base64 import b64encode +from pathlib import Path +from datetime import datetime +from typing import ( + Any, + Literal, + Mapping, + Optional, + Sequence, + Union, +) +from typing_extensions import Annotated + +from pydantic import ( + BaseModel, + ByteSize, + Field, + FilePath, + Base64Str, + model_serializer, +) +from pydantic.json_schema import JsonSchemaValue + + +class SubscriptableBaseModel(BaseModel): + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + +class Options(SubscriptableBaseModel): + # load time options + numa: Optional[bool] = None + num_ctx: Optional[int] = None + num_batch: Optional[int] = None + num_gpu: Optional[int] = None + main_gpu: Optional[int] = None + low_vram: Optional[bool] = None + f16_kv: Optional[bool] = None + logits_all: Optional[bool] = None + vocab_only: Optional[bool] = None + use_mmap: Optional[bool] = None + use_mlock: Optional[bool] = None + embedding_only: Optional[bool] = None + num_thread: Optional[int] = None + + # runtime options + num_keep: Optional[int] = None + seed: Optional[int] = None + num_predict: Optional[int] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + tfs_z: Optional[float] = None + typical_p: Optional[float] = None + repeat_last_n: Optional[int] = None + temperature: Optional[float] = None + repeat_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + mirostat: Optional[int] = None + mirostat_tau: Optional[float] = None + mirostat_eta: Optional[float] = None + penalize_newline: Optional[bool] = None + stop: Optional[Sequence[str]] = None + + +class BaseRequest(SubscriptableBaseModel): + model: Annotated[str, 
Field(min_length=1)]
+  'Model to use for the request.'
+
+
+class BaseStreamableRequest(BaseRequest):
+  stream: Optional[bool] = None
+  'Stream response.'
+
+
+class BaseGenerateRequest(BaseStreamableRequest):
+  options: Optional[Union[Mapping[str, Any], Options]] = None
+  'Options to use for the request.'
+
+  format: Optional[Literal['', 'json']] = None
+  'Format of the response.'
+
+  keep_alive: Optional[Union[float, str]] = None
+  'Keep model alive for the specified duration.'
+
+
+class Image(BaseModel):
+  value: Union[FilePath, Base64Str, bytes]

-import sys
+  @model_serializer
+  def serialize_model(self):
+    if isinstance(self.value, Path):
+      return b64encode(self.value.read_bytes()).decode()
+    elif isinstance(self.value, bytes):
+      return b64encode(self.value).decode()
+    return self.value

-if sys.version_info < (3, 11):
-  from typing_extensions import NotRequired
-else:
-  from typing import NotRequired

+class GenerateRequest(BaseGenerateRequest):
+  prompt: Optional[str] = None
+  'Prompt to generate response from.'

-class BaseGenerateResponse(TypedDict):
-  model: str
+  suffix: Optional[str] = None
+  'Suffix to append to the response.'
+
+  system: Optional[str] = None
+  'System prompt to prepend to the prompt.'
+
+  template: Optional[str] = None
+  'Template to use for the response.'
+
+  context: Optional[Sequence[int]] = None
+  'Tokenized history to use for the response.'
+
+  raw: Optional[bool] = None
+
+  images: Optional[Sequence[Image]] = None
+  'Image data for multimodal models.'
+
+
+class BaseGenerateResponse(SubscriptableBaseModel):
+  model: Optional[str] = None
   'Model used to generate response.'

-  created_at: str
+  created_at: Optional[str] = None
   'Time when the request was created.'

-  done: bool
+  done: Optional[bool] = None
   'True if response is complete, otherwise False. Useful for streaming to detect the final response.'

-  done_reason: str
+  done_reason: Optional[str] = None
   'Reason for completion. Only present when done is True.'

-  total_duration: int
+  total_duration: Optional[int] = None
   'Total duration in nanoseconds.'

-  load_duration: int
+  load_duration: Optional[int] = None
   'Load duration in nanoseconds.'

-  prompt_eval_count: int
+  prompt_eval_count: Optional[int] = None
   'Number of tokens evaluated in the prompt.'

-  prompt_eval_duration: int
+  prompt_eval_duration: Optional[int] = None
   'Duration of evaluating the prompt in nanoseconds.'

-  eval_count: int
+  eval_count: Optional[int] = None
   'Number of tokens evaluated in inference.'

-  eval_duration: int
+  eval_duration: Optional[int] = None
   'Duration of evaluating inference in nanoseconds.'

@@ -49,43 +168,22 @@ class GenerateResponse(BaseGenerateResponse):
   response: str
   'Response content. When streaming, this contains a fragment of the response.'

-  context: Sequence[int]
+  context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'


-class ToolCallFunction(TypedDict):
-  """
-  Tool call function.
-  """
-
-  name: str
-  'Name of the function.'
-
-  arguments: NotRequired[Mapping[str, Any]]
-  'Arguments of the function.'
-
-
-class ToolCall(TypedDict):
-  """
-  Model tool calls.
-  """
-
-  function: ToolCallFunction
-  'Function to be called.'
-
-
-class Message(TypedDict):
+class Message(SubscriptableBaseModel):
   """
   Chat message.
   """

   role: Literal['user', 'assistant', 'system', 'tool']
-  "Assumed role of the message. Response messages always has role 'assistant' or 'tool'."
+  "Assumed role of the message. Response messages have role 'assistant' or 'tool'."
- content: NotRequired[str] + content: Optional[str] = None 'Content of the message. Response messages contains message fragments when streaming.' - images: NotRequired[Sequence[Any]] + images: Optional[Sequence[Image]] = None """ Optional list of image data for multimodal models. @@ -97,33 +195,54 @@ class Message(TypedDict): Valid image formats depend on the model. See the model card for more information. """ - tool_calls: NotRequired[Sequence[ToolCall]] + class ToolCall(SubscriptableBaseModel): + """ + Model tool calls. + """ + + class Function(SubscriptableBaseModel): + """ + Tool call function. + """ + + name: str + 'Name of the function.' + + arguments: Mapping[str, Any] + 'Arguments of the function.' + + function: Function + 'Function to be called.' + + tool_calls: Optional[Sequence[ToolCall]] = None """ Tools calls to be made by the model. """ -class Property(TypedDict): - type: str - description: str - enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings +class Tool(SubscriptableBaseModel): + type: Literal['function'] = 'function' + class Function(SubscriptableBaseModel): + name: str + description: str -class Parameters(TypedDict): - type: str - required: Sequence[str] - properties: Mapping[str, Property] + class Parameters(SubscriptableBaseModel): + type: str + required: Optional[Sequence[str]] = None + properties: Optional[JsonSchemaValue] = None + parameters: Parameters -class ToolFunction(TypedDict): - name: str - description: str - parameters: Parameters + function: Function -class Tool(TypedDict): - type: str - function: ToolFunction +class ChatRequest(BaseGenerateRequest): + messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None + 'Messages to chat with.' + + tools: Optional[Sequence[Tool]] = None + 'Tools to use for the chat.' class ChatResponse(BaseGenerateResponse): @@ -135,47 +254,156 @@ class ChatResponse(BaseGenerateResponse): 'Response message.' -class ProgressResponse(TypedDict): - status: str - completed: int - total: int - digest: str +class EmbedRequest(BaseRequest): + input: Union[str, Sequence[str]] + 'Input text to embed.' + truncate: Optional[bool] = None + 'Truncate the input to the maximum token length.' -class Options(TypedDict, total=False): - # load time options - numa: bool - num_ctx: int - num_batch: int - num_gpu: int - main_gpu: int - low_vram: bool - f16_kv: bool - logits_all: bool - vocab_only: bool - use_mmap: bool - use_mlock: bool - embedding_only: bool - num_thread: int + options: Optional[Union[Mapping[str, Any], Options]] = None + 'Options to use for the request.' - # runtime options - num_keep: int - seed: int - num_predict: int - top_k: int - top_p: float - tfs_z: float - typical_p: float - repeat_last_n: int - temperature: float - repeat_penalty: float - presence_penalty: float - frequency_penalty: float - mirostat: int - mirostat_tau: float - mirostat_eta: float - penalize_newline: bool - stop: Sequence[str] + keep_alive: Optional[Union[float, str]] = None + + +class EmbedResponse(BaseGenerateResponse): + """ + Response returned by embed requests. + """ + + embeddings: Sequence[Sequence[float]] + 'Embeddings of the inputs.' + + +class EmbeddingsRequest(BaseRequest): + prompt: Optional[str] = None + 'Prompt to generate embeddings from.' + + options: Optional[Union[Mapping[str, Any], Options]] = None + 'Options to use for the request.' + + keep_alive: Optional[Union[float, str]] = None + + +class EmbeddingsResponse(SubscriptableBaseModel): + """ + Response returned by embeddings requests. 
+  """
+
+  embedding: Sequence[float]
+  'Embedding of the prompt.'
+
+
+class PullRequest(BaseStreamableRequest):
+  """
+  Request to pull the model.
+  """
+
+  insecure: Optional[bool] = None
+  'Allow insecure (HTTP) connections.'
+
+
+class PushRequest(BaseStreamableRequest):
+  """
+  Request to push the model.
+  """
+
+  insecure: Optional[bool] = None
+  'Allow insecure (HTTP) connections.'
+
+
+class CreateRequest(BaseStreamableRequest):
+  """
+  Request to create a new model.
+  """
+
+  modelfile: Optional[str] = None
+
+  quantize: Optional[str] = None
+
+
+class ModelDetails(SubscriptableBaseModel):
+  parent_model: Optional[str] = None
+  format: Optional[str] = None
+  family: Optional[str] = None
+  families: Optional[Sequence[str]] = None
+  parameter_size: Optional[str] = None
+  quantization_level: Optional[str] = None
+
+
+class ListResponse(SubscriptableBaseModel):
+  class Model(BaseModel):
+    modified_at: Optional[datetime] = None
+    digest: Optional[str] = None
+    size: Optional[ByteSize] = None
+    details: Optional[ModelDetails] = None
+
+  models: Sequence[Model]
+  'List of models.'
+
+
+class DeleteRequest(BaseRequest):
+  """
+  Request to delete a model.
+  """
+
+
+class CopyRequest(BaseModel):
+  """
+  Request to copy a model.
+  """
+
+  source: str
+  'Source model to copy.'
+
+  destination: str
+  'Destination model to copy to.'
+
+
+class StatusResponse(SubscriptableBaseModel):
+  status: Optional[str] = None
+
+
+class ProgressResponse(StatusResponse):
+  completed: Optional[int] = None
+  total: Optional[int] = None
+  digest: Optional[str] = None
+
+
+class ShowRequest(BaseRequest):
+  """
+  Request to show model information.
+  """
+
+
+class ShowResponse(SubscriptableBaseModel):
+  modified_at: Optional[datetime] = None
+
+  template: Optional[str] = None
+
+  modelfile: Optional[str] = None
+
+  license: Optional[str] = None
+
+  details: Optional[ModelDetails] = None
+
+  modelinfo: Optional[Mapping[str, Any]] = Field(None, alias='model_info')
+
+  parameters: Optional[str] = None
+
+
+class ProcessResponse(SubscriptableBaseModel):
+  class Model(BaseModel):
+    model: Optional[str] = None
+    name: Optional[str] = None
+    digest: Optional[str] = None
+    expires_at: Optional[datetime] = None
+    size: Optional[ByteSize] = None
+    size_vram: Optional[ByteSize] = None
+    details: Optional[ModelDetails] = None
+
+  models: Sequence[Model]


 class RequestError(Exception):
diff --git a/poetry.lock b/poetry.lock
index 483a203..be1061d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,5 +1,19 @@
 # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+[[package]] +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + [[package]] name = "anyio" version = "4.3.0" @@ -395,6 +409,130 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pydantic" +version = "2.9.0" +description = "Data validation using Python type hints" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.9.0-py3-none-any.whl", hash = "sha256:f66a7073abd93214a20c5f7b32d56843137a7a2e70d02111f3be287035c45370"}, + {file = "pydantic-2.9.0.tar.gz", hash = "sha256:c7a8a9fdf7d100afa49647eae340e2d23efa382466a8d177efcd1381e9be5598"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.23.2" +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] +tzdata = {version = "*", markers = "python_version >= \"3.9\""} + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.23.2" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.23.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7d0324a35ab436c9d768753cbc3c47a865a2cbc0757066cb864747baa61f6ece"}, + {file = "pydantic_core-2.23.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:276ae78153a94b664e700ac362587c73b84399bd1145e135287513442e7dfbc7"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:964c7aa318da542cdcc60d4a648377ffe1a2ef0eb1e996026c7f74507b720a78"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1cf842265a3a820ebc6388b963ead065f5ce8f2068ac4e1c713ef77a67b71f7c"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae90b9e50fe1bd115b24785e962b51130340408156d34d67b5f8f3fa6540938e"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ae65fdfb8a841556b52935dfd4c3f79132dc5253b12c0061b96415208f4d622"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c8aa40f6ca803f95b1c1c5aeaee6237b9e879e4dfb46ad713229a63651a95fb"}, + {file = "pydantic_core-2.23.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c53100c8ee5a1e102766abde2158077d8c374bee0639201f11d3032e3555dfbc"}, + {file = "pydantic_core-2.23.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d6b9dd6aa03c812017411734e496c44fef29b43dba1e3dd1fa7361bbacfc1354"}, + {file = "pydantic_core-2.23.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b18cf68255a476b927910c6873d9ed00da692bb293c5b10b282bd48a0afe3ae2"}, + {file = "pydantic_core-2.23.2-cp310-none-win32.whl", hash = "sha256:e460475719721d59cd54a350c1f71c797c763212c836bf48585478c5514d2854"}, + {file = "pydantic_core-2.23.2-cp310-none-win_amd64.whl", hash = 
"sha256:5f3cf3721eaf8741cffaf092487f1ca80831202ce91672776b02b875580e174a"}, + {file = "pydantic_core-2.23.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:7ce8e26b86a91e305858e018afc7a6e932f17428b1eaa60154bd1f7ee888b5f8"}, + {file = "pydantic_core-2.23.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7e9b24cca4037a561422bf5dc52b38d390fb61f7bfff64053ce1b72f6938e6b2"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:753294d42fb072aa1775bfe1a2ba1012427376718fa4c72de52005a3d2a22178"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:257d6a410a0d8aeb50b4283dea39bb79b14303e0fab0f2b9d617701331ed1515"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c8319e0bd6a7b45ad76166cc3d5d6a36c97d0c82a196f478c3ee5346566eebfd"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7a05c0240f6c711eb381ac392de987ee974fa9336071fb697768dfdb151345ce"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d5b0ff3218858859910295df6953d7bafac3a48d5cd18f4e3ed9999efd2245f"}, + {file = "pydantic_core-2.23.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:96ef39add33ff58cd4c112cbac076726b96b98bb8f1e7f7595288dcfb2f10b57"}, + {file = "pydantic_core-2.23.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0102e49ac7d2df3379ef8d658d3bc59d3d769b0bdb17da189b75efa861fc07b4"}, + {file = "pydantic_core-2.23.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a6612c2a844043e4d10a8324c54cdff0042c558eef30bd705770793d70b224aa"}, + {file = "pydantic_core-2.23.2-cp311-none-win32.whl", hash = "sha256:caffda619099cfd4f63d48462f6aadbecee3ad9603b4b88b60cb821c1b258576"}, + {file = "pydantic_core-2.23.2-cp311-none-win_amd64.whl", hash = "sha256:6f80fba4af0cb1d2344869d56430e304a51396b70d46b91a55ed4959993c0589"}, + {file = "pydantic_core-2.23.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:4c83c64d05ffbbe12d4e8498ab72bdb05bcc1026340a4a597dc647a13c1605ec"}, + {file = "pydantic_core-2.23.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6294907eaaccf71c076abdd1c7954e272efa39bb043161b4b8aa1cd76a16ce43"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a801c5e1e13272e0909c520708122496647d1279d252c9e6e07dac216accc41"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cc0c316fba3ce72ac3ab7902a888b9dc4979162d320823679da270c2d9ad0cad"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b06c5d4e8701ac2ba99a2ef835e4e1b187d41095a9c619c5b185c9068ed2a49"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82764c0bd697159fe9947ad59b6db6d7329e88505c8f98990eb07e84cc0a5d81"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b1a195efd347ede8bcf723e932300292eb13a9d2a3c1f84eb8f37cbbc905b7f"}, + {file = "pydantic_core-2.23.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7efb12e5071ad8d5b547487bdad489fbd4a5a35a0fc36a1941517a6ad7f23e0"}, + {file = "pydantic_core-2.23.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5dd0ec5f514ed40e49bf961d49cf1bc2c72e9b50f29a163b2cc9030c6742aa73"}, + {file = 
"pydantic_core-2.23.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:820f6ee5c06bc868335e3b6e42d7ef41f50dfb3ea32fbd523ab679d10d8741c0"}, + {file = "pydantic_core-2.23.2-cp312-none-win32.whl", hash = "sha256:3713dc093d5048bfaedbba7a8dbc53e74c44a140d45ede020dc347dda18daf3f"}, + {file = "pydantic_core-2.23.2-cp312-none-win_amd64.whl", hash = "sha256:e1895e949f8849bc2757c0dbac28422a04be031204df46a56ab34bcf98507342"}, + {file = "pydantic_core-2.23.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:da43cbe593e3c87d07108d0ebd73771dc414488f1f91ed2e204b0370b94b37ac"}, + {file = "pydantic_core-2.23.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:64d094ea1aa97c6ded4748d40886076a931a8bf6f61b6e43e4a1041769c39dd2"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:084414ffe9a85a52940b49631321d636dadf3576c30259607b75516d131fecd0"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:043ef8469f72609c4c3a5e06a07a1f713d53df4d53112c6d49207c0bd3c3bd9b"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3649bd3ae6a8ebea7dc381afb7f3c6db237fc7cebd05c8ac36ca8a4187b03b30"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6db09153d8438425e98cdc9a289c5fade04a5d2128faff8f227c459da21b9703"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5668b3173bb0b2e65020b60d83f5910a7224027232c9f5dc05a71a1deac9f960"}, + {file = "pydantic_core-2.23.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c7b81beaf7c7ebde978377dc53679c6cba0e946426fc7ade54251dfe24a7604"}, + {file = "pydantic_core-2.23.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:ae579143826c6f05a361d9546446c432a165ecf1c0b720bbfd81152645cb897d"}, + {file = "pydantic_core-2.23.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:19f1352fe4b248cae22a89268720fc74e83f008057a652894f08fa931e77dced"}, + {file = "pydantic_core-2.23.2-cp313-none-win32.whl", hash = "sha256:e1a79ad49f346aa1a2921f31e8dbbab4d64484823e813a002679eaa46cba39e1"}, + {file = "pydantic_core-2.23.2-cp313-none-win_amd64.whl", hash = "sha256:582871902e1902b3c8e9b2c347f32a792a07094110c1bca6c2ea89b90150caac"}, + {file = "pydantic_core-2.23.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:743e5811b0c377eb830150d675b0847a74a44d4ad5ab8845923d5b3a756d8100"}, + {file = "pydantic_core-2.23.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6650a7bbe17a2717167e3e23c186849bae5cef35d38949549f1c116031b2b3aa"}, + {file = "pydantic_core-2.23.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56e6a12ec8d7679f41b3750ffa426d22b44ef97be226a9bab00a03365f217b2b"}, + {file = "pydantic_core-2.23.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:810ca06cca91de9107718dc83d9ac4d2e86efd6c02cba49a190abcaf33fb0472"}, + {file = "pydantic_core-2.23.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:785e7f517ebb9890813d31cb5d328fa5eda825bb205065cde760b3150e4de1f7"}, + {file = "pydantic_core-2.23.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ef71ec876fcc4d3bbf2ae81961959e8d62f8d74a83d116668409c224012e3af"}, + {file = "pydantic_core-2.23.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d50ac34835c6a4a0d456b5db559b82047403c4317b3bc73b3455fefdbdc54b0a"}, + {file = 
"pydantic_core-2.23.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:16b25a4a120a2bb7dab51b81e3d9f3cde4f9a4456566c403ed29ac81bf49744f"}, + {file = "pydantic_core-2.23.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:41ae8537ad371ec018e3c5da0eb3f3e40ee1011eb9be1da7f965357c4623c501"}, + {file = "pydantic_core-2.23.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07049ec9306ec64e955b2e7c40c8d77dd78ea89adb97a2013d0b6e055c5ee4c5"}, + {file = "pydantic_core-2.23.2-cp38-none-win32.whl", hash = "sha256:086c5db95157dc84c63ff9d96ebb8856f47ce113c86b61065a066f8efbe80acf"}, + {file = "pydantic_core-2.23.2-cp38-none-win_amd64.whl", hash = "sha256:67b6655311b00581914aba481729971b88bb8bc7996206590700a3ac85e457b8"}, + {file = "pydantic_core-2.23.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:358331e21a897151e54d58e08d0219acf98ebb14c567267a87e971f3d2a3be59"}, + {file = "pydantic_core-2.23.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c4d9f15ffe68bcd3898b0ad7233af01b15c57d91cd1667f8d868e0eacbfe3f87"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0123655fedacf035ab10c23450163c2f65a4174f2bb034b188240a6cf06bb123"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e6e3ccebdbd6e53474b0bb7ab8b88e83c0cfe91484b25e058e581348ee5a01a5"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc535cb898ef88333cf317777ecdfe0faac1c2a3187ef7eb061b6f7ecf7e6bae"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aab9e522efff3993a9e98ab14263d4e20211e62da088298089a03056980a3e69"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05b366fb8fe3d8683b11ac35fa08947d7b92be78ec64e3277d03bd7f9b7cda79"}, + {file = "pydantic_core-2.23.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7568f682c06f10f30ef643a1e8eec4afeecdafde5c4af1b574c6df079e96f96c"}, + {file = "pydantic_core-2.23.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cdd02a08205dc90238669f082747612cb3c82bd2c717adc60f9b9ecadb540f80"}, + {file = "pydantic_core-2.23.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1a2ab4f410f4b886de53b6bddf5dd6f337915a29dd9f22f20f3099659536b2f6"}, + {file = "pydantic_core-2.23.2-cp39-none-win32.whl", hash = "sha256:0448b81c3dfcde439551bb04a9f41d7627f676b12701865c8a2574bcea034437"}, + {file = "pydantic_core-2.23.2-cp39-none-win_amd64.whl", hash = "sha256:4cebb9794f67266d65e7e4cbe5dcf063e29fc7b81c79dc9475bd476d9534150e"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e758d271ed0286d146cf7c04c539a5169a888dd0b57026be621547e756af55bc"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f477d26183e94eaafc60b983ab25af2a809a1b48ce4debb57b343f671b7a90b6"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da3131ef2b940b99106f29dfbc30d9505643f766704e14c5d5e504e6a480c35e"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:329a721253c7e4cbd7aad4a377745fbcc0607f9d72a3cc2102dd40519be75ed2"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7706e15cdbf42f8fab1e6425247dfa98f4a6f8c63746c995d6a2017f78e619ae"}, + {file = 
"pydantic_core-2.23.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e64ffaf8f6e17ca15eb48344d86a7a741454526f3a3fa56bc493ad9d7ec63936"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dd59638025160056687d598b054b64a79183f8065eae0d3f5ca523cde9943940"}, + {file = "pydantic_core-2.23.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:12625e69b1199e94b0ae1c9a95d000484ce9f0182f9965a26572f054b1537e44"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5d813fd871b3d5c3005157622ee102e8908ad6011ec915a18bd8fde673c4360e"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1eb37f7d6a8001c0f86dc8ff2ee8d08291a536d76e49e78cda8587bb54d8b329"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ce7eaf9a98680b4312b7cebcdd9352531c43db00fca586115845df388f3c465"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f087879f1ffde024dd2788a30d55acd67959dcf6c431e9d3682d1c491a0eb474"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ce883906810b4c3bd90e0ada1f9e808d9ecf1c5f0b60c6b8831d6100bcc7dd6"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8031074a397a5925d06b590121f8339d34a5a74cfe6970f8a1124eb8b83f4ac"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:23af245b8f2f4ee9e2c99cb3f93d0e22fb5c16df3f2f643f5a8da5caff12a653"}, + {file = "pydantic_core-2.23.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c57e493a0faea1e4c38f860d6862ba6832723396c884fbf938ff5e9b224200e2"}, + {file = "pydantic_core-2.23.2.tar.gz", hash = "sha256:95d6bf449a1ac81de562d65d180af5d8c19672793c81877a2eda8fde5d08f2fd"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pytest" version = "8.3.2" @@ -527,6 +665,28 @@ files = [ {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "werkzeug" version = "3.0.1" @@ -547,4 +707,4 @@ watchdog = ["watchdog (>=2.3)"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "e36516c932ab9dd7497acc0c3d55ab2c963004595efe97c2bc80854687c32c1e" +content-hash = "e664c86cc330480eb86239842f55f12b0fba4df5c2fc776d094f37f58320e637" diff --git a/pyproject.toml b/pyproject.toml index 3adf10f..332a3cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ repository = 
"https://github.com/jmorganca/ollama-python" [tool.poetry.dependencies] python = "^3.8" httpx = "^0.27.0" +pydantic = "^2.9.0" [tool.poetry.group.dev.dependencies] pytest = ">=7.4.3,<9.0.0" diff --git a/requirements.txt b/requirements.txt index f1dde1f..f065f01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ +annotated-types==0.7.0 ; python_version >= "3.8" and python_version < "4.0" \ + --hash=sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53 \ + --hash=sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89 anyio==4.3.0 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8 \ --hash=sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6 @@ -19,9 +22,105 @@ httpx==0.27.0 ; python_version >= "3.8" and python_version < "4.0" \ idna==3.6 ; python_version >= "3.8" and python_version < "4.0" \ --hash=sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca \ --hash=sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f +pydantic-core==2.23.2 ; python_version >= "3.8" and python_version < "4.0" \ + --hash=sha256:0102e49ac7d2df3379ef8d658d3bc59d3d769b0bdb17da189b75efa861fc07b4 \ + --hash=sha256:0123655fedacf035ab10c23450163c2f65a4174f2bb034b188240a6cf06bb123 \ + --hash=sha256:043ef8469f72609c4c3a5e06a07a1f713d53df4d53112c6d49207c0bd3c3bd9b \ + --hash=sha256:0448b81c3dfcde439551bb04a9f41d7627f676b12701865c8a2574bcea034437 \ + --hash=sha256:05b366fb8fe3d8683b11ac35fa08947d7b92be78ec64e3277d03bd7f9b7cda79 \ + --hash=sha256:07049ec9306ec64e955b2e7c40c8d77dd78ea89adb97a2013d0b6e055c5ee4c5 \ + --hash=sha256:084414ffe9a85a52940b49631321d636dadf3576c30259607b75516d131fecd0 \ + --hash=sha256:086c5db95157dc84c63ff9d96ebb8856f47ce113c86b61065a066f8efbe80acf \ + --hash=sha256:12625e69b1199e94b0ae1c9a95d000484ce9f0182f9965a26572f054b1537e44 \ + --hash=sha256:16b25a4a120a2bb7dab51b81e3d9f3cde4f9a4456566c403ed29ac81bf49744f \ + --hash=sha256:19f1352fe4b248cae22a89268720fc74e83f008057a652894f08fa931e77dced \ + --hash=sha256:1a2ab4f410f4b886de53b6bddf5dd6f337915a29dd9f22f20f3099659536b2f6 \ + --hash=sha256:1c7b81beaf7c7ebde978377dc53679c6cba0e946426fc7ade54251dfe24a7604 \ + --hash=sha256:1cf842265a3a820ebc6388b963ead065f5ce8f2068ac4e1c713ef77a67b71f7c \ + --hash=sha256:1eb37f7d6a8001c0f86dc8ff2ee8d08291a536d76e49e78cda8587bb54d8b329 \ + --hash=sha256:23af245b8f2f4ee9e2c99cb3f93d0e22fb5c16df3f2f643f5a8da5caff12a653 \ + --hash=sha256:257d6a410a0d8aeb50b4283dea39bb79b14303e0fab0f2b9d617701331ed1515 \ + --hash=sha256:276ae78153a94b664e700ac362587c73b84399bd1145e135287513442e7dfbc7 \ + --hash=sha256:2b1a195efd347ede8bcf723e932300292eb13a9d2a3c1f84eb8f37cbbc905b7f \ + --hash=sha256:329a721253c7e4cbd7aad4a377745fbcc0607f9d72a3cc2102dd40519be75ed2 \ + --hash=sha256:358331e21a897151e54d58e08d0219acf98ebb14c567267a87e971f3d2a3be59 \ + --hash=sha256:3649bd3ae6a8ebea7dc381afb7f3c6db237fc7cebd05c8ac36ca8a4187b03b30 \ + --hash=sha256:3713dc093d5048bfaedbba7a8dbc53e74c44a140d45ede020dc347dda18daf3f \ + --hash=sha256:3ef71ec876fcc4d3bbf2ae81961959e8d62f8d74a83d116668409c224012e3af \ + --hash=sha256:41ae8537ad371ec018e3c5da0eb3f3e40ee1011eb9be1da7f965357c4623c501 \ + --hash=sha256:4a801c5e1e13272e0909c520708122496647d1279d252c9e6e07dac216accc41 \ + --hash=sha256:4c83c64d05ffbbe12d4e8498ab72bdb05bcc1026340a4a597dc647a13c1605ec \ + --hash=sha256:4cebb9794f67266d65e7e4cbe5dcf063e29fc7b81c79dc9475bd476d9534150e \ + 
--hash=sha256:5668b3173bb0b2e65020b60d83f5910a7224027232c9f5dc05a71a1deac9f960 \ + --hash=sha256:56e6a12ec8d7679f41b3750ffa426d22b44ef97be226a9bab00a03365f217b2b \ + --hash=sha256:582871902e1902b3c8e9b2c347f32a792a07094110c1bca6c2ea89b90150caac \ + --hash=sha256:5c8aa40f6ca803f95b1c1c5aeaee6237b9e879e4dfb46ad713229a63651a95fb \ + --hash=sha256:5d813fd871b3d5c3005157622ee102e8908ad6011ec915a18bd8fde673c4360e \ + --hash=sha256:5dd0ec5f514ed40e49bf961d49cf1bc2c72e9b50f29a163b2cc9030c6742aa73 \ + --hash=sha256:5f3cf3721eaf8741cffaf092487f1ca80831202ce91672776b02b875580e174a \ + --hash=sha256:6294907eaaccf71c076abdd1c7954e272efa39bb043161b4b8aa1cd76a16ce43 \ + --hash=sha256:64d094ea1aa97c6ded4748d40886076a931a8bf6f61b6e43e4a1041769c39dd2 \ + --hash=sha256:6650a7bbe17a2717167e3e23c186849bae5cef35d38949549f1c116031b2b3aa \ + --hash=sha256:67b6655311b00581914aba481729971b88bb8bc7996206590700a3ac85e457b8 \ + --hash=sha256:6b06c5d4e8701ac2ba99a2ef835e4e1b187d41095a9c619c5b185c9068ed2a49 \ + --hash=sha256:6ce883906810b4c3bd90e0ada1f9e808d9ecf1c5f0b60c6b8831d6100bcc7dd6 \ + --hash=sha256:6db09153d8438425e98cdc9a289c5fade04a5d2128faff8f227c459da21b9703 \ + --hash=sha256:6f80fba4af0cb1d2344869d56430e304a51396b70d46b91a55ed4959993c0589 \ + --hash=sha256:743e5811b0c377eb830150d675b0847a74a44d4ad5ab8845923d5b3a756d8100 \ + --hash=sha256:753294d42fb072aa1775bfe1a2ba1012427376718fa4c72de52005a3d2a22178 \ + --hash=sha256:7568f682c06f10f30ef643a1e8eec4afeecdafde5c4af1b574c6df079e96f96c \ + --hash=sha256:7706e15cdbf42f8fab1e6425247dfa98f4a6f8c63746c995d6a2017f78e619ae \ + --hash=sha256:785e7f517ebb9890813d31cb5d328fa5eda825bb205065cde760b3150e4de1f7 \ + --hash=sha256:7a05c0240f6c711eb381ac392de987ee974fa9336071fb697768dfdb151345ce \ + --hash=sha256:7ce7eaf9a98680b4312b7cebcdd9352531c43db00fca586115845df388f3c465 \ + --hash=sha256:7ce8e26b86a91e305858e018afc7a6e932f17428b1eaa60154bd1f7ee888b5f8 \ + --hash=sha256:7d0324a35ab436c9d768753cbc3c47a865a2cbc0757066cb864747baa61f6ece \ + --hash=sha256:7e9b24cca4037a561422bf5dc52b38d390fb61f7bfff64053ce1b72f6938e6b2 \ + --hash=sha256:810ca06cca91de9107718dc83d9ac4d2e86efd6c02cba49a190abcaf33fb0472 \ + --hash=sha256:820f6ee5c06bc868335e3b6e42d7ef41f50dfb3ea32fbd523ab679d10d8741c0 \ + --hash=sha256:82764c0bd697159fe9947ad59b6db6d7329e88505c8f98990eb07e84cc0a5d81 \ + --hash=sha256:8ae65fdfb8a841556b52935dfd4c3f79132dc5253b12c0061b96415208f4d622 \ + --hash=sha256:8d5b0ff3218858859910295df6953d7bafac3a48d5cd18f4e3ed9999efd2245f \ + --hash=sha256:95d6bf449a1ac81de562d65d180af5d8c19672793c81877a2eda8fde5d08f2fd \ + --hash=sha256:964c7aa318da542cdcc60d4a648377ffe1a2ef0eb1e996026c7f74507b720a78 \ + --hash=sha256:96ef39add33ff58cd4c112cbac076726b96b98bb8f1e7f7595288dcfb2f10b57 \ + --hash=sha256:a6612c2a844043e4d10a8324c54cdff0042c558eef30bd705770793d70b224aa \ + --hash=sha256:a8031074a397a5925d06b590121f8339d34a5a74cfe6970f8a1124eb8b83f4ac \ + --hash=sha256:aab9e522efff3993a9e98ab14263d4e20211e62da088298089a03056980a3e69 \ + --hash=sha256:ae579143826c6f05a361d9546446c432a165ecf1c0b720bbfd81152645cb897d \ + --hash=sha256:ae90b9e50fe1bd115b24785e962b51130340408156d34d67b5f8f3fa6540938e \ + --hash=sha256:b18cf68255a476b927910c6873d9ed00da692bb293c5b10b282bd48a0afe3ae2 \ + --hash=sha256:b7efb12e5071ad8d5b547487bdad489fbd4a5a35a0fc36a1941517a6ad7f23e0 \ + --hash=sha256:c4d9f15ffe68bcd3898b0ad7233af01b15c57d91cd1667f8d868e0eacbfe3f87 \ + --hash=sha256:c53100c8ee5a1e102766abde2158077d8c374bee0639201f11d3032e3555dfbc \ + 
--hash=sha256:c57e493a0faea1e4c38f860d6862ba6832723396c884fbf938ff5e9b224200e2 \
+    --hash=sha256:c8319e0bd6a7b45ad76166cc3d5d6a36c97d0c82a196f478c3ee5346566eebfd \
+    --hash=sha256:caffda619099cfd4f63d48462f6aadbecee3ad9603b4b88b60cb821c1b258576 \
+    --hash=sha256:cc0c316fba3ce72ac3ab7902a888b9dc4979162d320823679da270c2d9ad0cad \
+    --hash=sha256:cdd02a08205dc90238669f082747612cb3c82bd2c717adc60f9b9ecadb540f80 \
+    --hash=sha256:d50ac34835c6a4a0d456b5db559b82047403c4317b3bc73b3455fefdbdc54b0a \
+    --hash=sha256:d6b9dd6aa03c812017411734e496c44fef29b43dba1e3dd1fa7361bbacfc1354 \
+    --hash=sha256:da3131ef2b940b99106f29dfbc30d9505643f766704e14c5d5e504e6a480c35e \
+    --hash=sha256:da43cbe593e3c87d07108d0ebd73771dc414488f1f91ed2e204b0370b94b37ac \
+    --hash=sha256:dd59638025160056687d598b054b64a79183f8065eae0d3f5ca523cde9943940 \
+    --hash=sha256:e1895e949f8849bc2757c0dbac28422a04be031204df46a56ab34bcf98507342 \
+    --hash=sha256:e1a79ad49f346aa1a2921f31e8dbbab4d64484823e813a002679eaa46cba39e1 \
+    --hash=sha256:e460475719721d59cd54a350c1f71c797c763212c836bf48585478c5514d2854 \
+    --hash=sha256:e64ffaf8f6e17ca15eb48344d86a7a741454526f3a3fa56bc493ad9d7ec63936 \
+    --hash=sha256:e6e3ccebdbd6e53474b0bb7ab8b88e83c0cfe91484b25e058e581348ee5a01a5 \
+    --hash=sha256:e758d271ed0286d146cf7c04c539a5169a888dd0b57026be621547e756af55bc \
+    --hash=sha256:f087879f1ffde024dd2788a30d55acd67959dcf6c431e9d3682d1c491a0eb474 \
+    --hash=sha256:f477d26183e94eaafc60b983ab25af2a809a1b48ce4debb57b343f671b7a90b6 \
+    --hash=sha256:fc535cb898ef88333cf317777ecdfe0faac1c2a3187ef7eb061b6f7ecf7e6bae
+pydantic==2.9.0 ; python_version >= "3.8" and python_version < "4.0" \
+    --hash=sha256:c7a8a9fdf7d100afa49647eae340e2d23efa382466a8d177efcd1381e9be5598 \
+    --hash=sha256:f66a7073abd93214a20c5f7b32d56843137a7a2e70d02111f3be287035c45370
 sniffio==1.3.1 ; python_version >= "3.8" and python_version < "4.0" \
     --hash=sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2 \
     --hash=sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc
-typing-extensions==4.10.0 ; python_version >= "3.8" and python_version < "3.11" \
-    --hash=sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475 \
-    --hash=sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb
+typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "4.0" \
+    --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
+    --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
+tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0" \
+    --hash=sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd \
+    --hash=sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252
diff --git a/tests/test_client.py b/tests/test_client.py
index efc8d4f..3bb451c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -28,9 +28,6 @@ def test_client_chat(httpserver: HTTPServer):
       'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
       'tools': [],
       'stream': False,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -76,9 +73,6 @@ def generate():
       'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
       'tools': [],
       'stream': True,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
@@ -106,9 +100,6 @@ def test_client_chat_images(httpserver: HTTPServer):
       ],
       'tools': [],
       'stream': False,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -137,16 +128,7 @@ def test_client_generate(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': False,
-      'raw': False,
-      'images': [],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -183,16 +165,7 @@ def generate():
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': True,
-      'raw': False,
-      'images': [],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
@@ -212,16 +185,8 @@ def test_client_generate_images(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': False,
-      'raw': False,
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -244,15 +209,11 @@ def test_client_pull(httpserver: HTTPServer):
     '/api/pull',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': False,
     },
-  ).respond_with_json(
-    {
-      'status': 'success',
-    }
-  )
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
   response = client.pull('dummy')
@@ -274,7 +235,7 @@ def generate():
     '/api/pull',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': True,
     },
@@ -293,15 +254,15 @@ def test_client_push(httpserver: HTTPServer):
     '/api/push',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': False,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
   response = client.push('dummy')
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 def test_client_push_stream(httpserver: HTTPServer):
@@ -317,7 +278,7 @@ def generate():
     '/api/push',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': True,
     },
@@ -337,12 +298,11 @@ def test_client_create_path(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
@@ -352,7 +312,7 @@ def test_client_create_path(httpserver: HTTPServer):
     modelfile.flush()
 
     response = client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 def test_client_create_path_relative(httpserver: HTTPServer):
@@ -361,12 +321,11 @@ def test_client_create_path_relative(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
@@ -376,7 +335,7 @@ def test_client_create_path_relative(httpserver: HTTPServer):
     modelfile.flush()
 
     response = client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 @pytest.fixture
@@ -394,12 +353,11 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
      'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
@@ -409,7 +367,7 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
     modelfile.flush()
 
     response = client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 def test_client_create_modelfile(httpserver: HTTPServer):
@@ -418,18 +376,17 @@ def test_client_create_modelfile(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
 
   with tempfile.NamedTemporaryFile() as blob:
     response = client.create('dummy', modelfile=f'FROM {blob.name}')
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
@@ -438,7 +395,7 @@ def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
 TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>> {{.Prompt}} [/INST]"""
@@ -452,9 +409,8 @@ def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
 PARAMETER stop <<SYS>>
 PARAMETER stop <</SYS>>''',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
@@ -478,7 +434,7 @@ def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
       ]
     ),
   )
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 def test_client_create_from_library(httpserver: HTTPServer):
@@ -486,17 +442,16 @@
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM llama2',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = Client(httpserver.url_for('/'))
 
   response = client.create('dummy', modelfile='FROM llama2')
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 def test_client_create_blob(httpserver: HTTPServer):
@@ -524,14 +479,14 @@ def test_client_delete(httpserver: HTTPServer):
   httpserver.expect_ordered_request(PrefixPattern('/api/delete'), method='DELETE').respond_with_response(Response(status=200))
   client = Client(httpserver.url_for('/api/delete'))
   response = client.delete('dummy')
-  assert response == {'status': 'success'}
+  assert response['status'] == 'success'
 
 
 def test_client_copy(httpserver: HTTPServer):
   httpserver.expect_ordered_request(PrefixPattern('/api/copy'), method='POST').respond_with_response(Response(status=200))
   client = Client(httpserver.url_for('/api/copy'))
   response = client.copy('dum', 'dummer')
-  assert response == {'status': 'success'}
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -544,15 +499,22 @@ async def test_async_client_chat(httpserver: HTTPServer):
       'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
       'tools': [],
       'stream': False,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': "I don't know.",
+      },
+    }
+  )
 
   client = AsyncClient(httpserver.url_for('/'))
   response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
-  assert isinstance(response, dict)
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == "I don't know."
@@ -583,9 +545,6 @@ def generate():
       'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
       'tools': [],
       'stream': True,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
@@ -614,18 +573,25 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
       ],
       'tools': [],
       'stream': False,
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': "I don't know.",
+      },
+    }
+  )
 
   client = AsyncClient(httpserver.url_for('/'))
 
   with io.BytesIO() as b:
     Image.new('RGB', (1, 1)).save(b, 'PNG')
     response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
-    assert isinstance(response, dict)
+    assert response['model'] == 'dummy'
+    assert response['message']['role'] == 'assistant'
+    assert response['message']['content'] == "I don't know."
@@ -636,22 +602,19 @@ async def test_async_client_generate(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': False,
-      'raw': False,
-      'images': [],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Because it is.',
+    }
+  )
 
   client = AsyncClient(httpserver.url_for('/'))
   response = await client.generate('dummy', 'Why is the sky blue?')
-  assert isinstance(response, dict)
+  assert response['model'] == 'dummy'
+  assert response['response'] == 'Because it is.'
@@ -677,16 +640,7 @@ def generate():
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': True,
-      'raw': False,
-      'images': [],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
@@ -707,25 +661,23 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
     json={
       'model': 'dummy',
       'prompt': 'Why is the sky blue?',
-      'suffix': '',
-      'system': '',
-      'template': '',
-      'context': [],
       'stream': False,
-      'raw': False,
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
-      'format': '',
-      'options': {},
-      'keep_alive': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Because it is.',
+    }
+  )
 
   client = AsyncClient(httpserver.url_for('/'))
 
   with tempfile.NamedTemporaryFile() as temp:
     Image.new('RGB', (1, 1)).save(temp, 'PNG')
 
     response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
-    assert isinstance(response, dict)
+    assert response['model'] == 'dummy'
+    assert response['response'] == 'Because it is.'
 
 
 @pytest.mark.asyncio
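Note on the chat/generate hunks above: the stubbed JSON bodies are now parsed into typed response models (ChatResponse, GenerateResponse) rather than returned as raw dicts, which is why the assertions index into named fields instead of checking isinstance(response, dict). A minimal usage sketch under that assumption -- the field names simply mirror the JSON keys mocked in these tests:

    # Hedged sketch: assumes generate() returns a GenerateResponse whose
    # fields mirror the /api/generate JSON ('model', 'response').
    from ollama import Client

    client = Client()  # defaults to OLLAMA_HOST
    res = client.generate('dummy', 'Why is the sky blue?')
    print(res.response)     # attribute access on the typed model
    print(res['response'])  # item access, as the updated tests use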
@@ -734,15 +686,15 @@ async def test_async_client_pull(httpserver: HTTPServer):
     '/api/pull',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': False,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
   response = await client.pull('dummy')
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -761,7 +713,7 @@ def generate():
     '/api/pull',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': True,
     },
@@ -781,15 +733,15 @@ async def test_async_client_push(httpserver: HTTPServer):
     '/api/push',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': False,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
   response = await client.push('dummy')
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -806,7 +758,7 @@ def generate():
     '/api/push',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'insecure': False,
       'stream': True,
     },
@@ -827,12 +779,11 @@ async def test_async_client_create_path(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
@@ -842,7 +793,7 @@ async def test_async_client_create_path(httpserver: HTTPServer):
     modelfile.flush()
 
     response = await client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -852,12 +803,11 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
@@ -867,7 +817,7 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer):
     modelfile.flush()
 
     response = await client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -877,12 +827,11 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
@@ -892,7 +841,7 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho
     modelfile.flush()
 
     response = await client.create('dummy', path=modelfile.name)
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -902,18 +851,17 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
 
   with tempfile.NamedTemporaryFile() as blob:
     response = await client.create('dummy', modelfile=f'FROM {blob.name}')
-    assert isinstance(response, dict)
+    assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -923,7 +871,7 @@ async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
 TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>> {{.Prompt}} [/INST]"""
@@ -937,9 +885,8 @@ async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
 PARAMETER stop <<SYS>>
 PARAMETER stop <</SYS>>''',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
@@ -963,7 +910,7 @@ async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
       ]
     ),
   )
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -972,17 +919,16 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
     '/api/create',
     method='POST',
     json={
-      'name': 'dummy',
+      'model': 'dummy',
       'modelfile': 'FROM llama2',
       'stream': False,
-      'quantize': None,
     },
-  ).respond_with_json({})
+  ).respond_with_json({'status': 'success'})
 
   client = AsyncClient(httpserver.url_for('/'))
 
   response = await client.create('dummy', modelfile='FROM llama2')
-  assert isinstance(response, dict)
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -1013,7 +959,7 @@ async def test_async_client_delete(httpserver: HTTPServer):
   httpserver.expect_ordered_request(PrefixPattern('/api/delete'), method='DELETE').respond_with_response(Response(status=200))
   client = AsyncClient(httpserver.url_for('/api/delete'))
   response = await client.delete('dummy')
-  assert response == {'status': 'success'}
+  assert response['status'] == 'success'
 
 
 @pytest.mark.asyncio
@@ -1021,4 +967,4 @@ async def test_async_client_copy(httpserver: HTTPServer):
   httpserver.expect_ordered_request(PrefixPattern('/api/copy'), method='POST').respond_with_response(Response(status=200))
   client = AsyncClient(httpserver.url_for('/api/copy'))
   response = await client.copy('dum', 'dummer')
-  assert response == {'status': 'success'}
+  assert response['status'] == 'success'
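Across these hunks the assertions move from comparing whole dicts (assert response == {'status': 'success'}) to subscripting the response (assert response['status'] == 'success'). A plain pydantic.BaseModel does not implement __getitem__, so this pattern only works if the response types layer item access on top of attribute access. A minimal, illustrative sketch of that idea -- the class names here are assumptions for illustration, not necessarily the library's:

    # Illustrative only: a pydantic base that forwards item access to
    # attributes, so r['status'] and r.status are interchangeable after
    # the dict -> model migration.
    from pydantic import BaseModel

    class SubscriptableModel(BaseModel):
      def __getitem__(self, key: str):
        return getattr(self, key)

    class StatusResponse(SubscriptableModel):
      status: str = ''

    r = StatusResponse(status='success')
    assert r['status'] == r.status == 'success'

This keeps the migration backwards compatible: existing callers that treat responses as mappings continue to work, while new code gets typed attribute access and validation for free.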