diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py
index 68c4d39..25fd1c3 100644
--- a/custom_components/extended_openai_conversation/__init__.py
+++ b/custom_components/extended_openai_conversation/__init__.py
@@ -6,8 +6,13 @@
 import json
 import yaml
 
-import openai
-from openai import error
+from openai import AsyncOpenAI, AsyncAzureOpenAI
+from openai.types.chat.chat_completion import (
+    Choice,
+    ChatCompletion,
+    ChatCompletionMessage,
+)
+from openai._exceptions import OpenAIError, AuthenticationError
 
 from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigEntry
@@ -41,7 +46,6 @@
     CONF_BASE_URL,
     CONF_API_VERSION,
     CONF_SKIP_AUTHENTICATION,
-    CONF_MODEL_KEY,
     DEFAULT_ATTACH_USERNAME,
     DEFAULT_CHAT_MODEL,
     DEFAULT_MAX_TOKENS,
@@ -77,8 +81,7 @@
     convert_to_template,
     validate_authentication,
     get_function_executor,
-    get_api_type,
-    get_default_model_key,
+    is_azure,
 )
 
 
@@ -105,10 +108,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
                 CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
             ),
         )
-    except error.AuthenticationError as err:
+    except AuthenticationError as err:
         _LOGGER.error("Invalid API key: %s", err)
         return False
-    except error.OpenAIError as err:
+    except OpenAIError as err:
         raise ConfigEntryNotReady(err) from err
 
     agent = OpenAIAgent(hass, entry)
@@ -136,6 +139,11 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
         self.hass = hass
         self.entry = entry
         self.history: dict[str, list[dict]] = {}
+        base_url = entry.data.get(CONF_BASE_URL)
+        if is_azure(base_url):
+            self.client = AsyncAzureOpenAI(api_key=entry.data[CONF_API_KEY], azure_endpoint=base_url, api_version=entry.data.get(CONF_API_VERSION))
+        else:
+            self.client = AsyncOpenAI(api_key=entry.data[CONF_API_KEY], base_url=base_url)
 
     @property
     def supported_languages(self) -> list[str] | Literal["*"]:
@@ -177,7 +185,7 @@ async def async_process(
 
         try:
            response = await self.query(user_input, messages, exposed_entities, 0)
-        except error.OpenAIError as err:
+        except OpenAIError as err:
             _LOGGER.error(err)
             intent_response = intent.IntentResponse(language=user_input.language)
             intent_response.async_set_error(
@@ -198,11 +206,11 @@ async def async_process(
                 response=intent_response, conversation_id=conversation_id
             )
 
-        messages.append(response)
+        messages.append(response.model_dump(exclude_none=True))
         self.history[conversation_id] = messages
 
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response["content"])
+        intent_response.async_set_speech(response.content)
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
@@ -269,10 +277,6 @@ async def query(
         n_requests,
     ):
         """Process a sentence."""
-        api_base = self.entry.data.get(CONF_BASE_URL)
-        api_key = self.entry.data[CONF_API_KEY]
-        api_type = get_api_type(api_base)
-        api_version = self.entry.data.get(CONF_API_VERSION)
         model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
         top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -288,18 +292,10 @@ async def query(
             functions = None
             function_call = None
 
-        model_key = self.entry.options.get(
-            CONF_MODEL_KEY, get_default_model_key(api_base)
-        )
-        model_kwargs = {model_key: model}
-
         _LOGGER.info("Prompt for %s: %s", model, messages)
 
-        response = await openai.ChatCompletion.acreate(
-            api_base=api_base,
-            api_key=api_key,
-            api_type=api_type,
-            api_version=api_version,
+        response: ChatCompletion = await self.client.chat.completions.create(
+            model=model,
             messages=messages,
             max_tokens=max_tokens,
             top_p=top_p,
@@ -307,12 +303,13 @@ async def query(
             user=user_input.conversation_id,
             functions=functions,
             function_call=function_call,
-            **model_kwargs,
         )
+        _LOGGER.info("Response %s", response)
 
-        message = response["choices"][0]["message"]
-        if message.get("function_call"):
+        choice: Choice = response.choices[0]
+        message = choice.message
+        if choice.finish_reason == "function_call":
             message = await self.execute_function_call(
                 user_input, messages, message, exposed_entities, n_requests + 1
             )
 
@@ -322,11 +319,11 @@ def execute_function_call(
         self,
         user_input: conversation.ConversationInput,
         messages,
-        message,
+        message: ChatCompletionMessage,
         exposed_entities,
         n_requests,
     ):
-        function_name = message["function_call"]["name"]
+        function_name = message.function_call.name
         function = next(
             (s for s in self.get_functions() if s["spec"]["name"] == function_name),
             None,
@@ -340,13 +337,13 @@ def execute_function_call(
                 n_requests,
                 function,
             )
-        raise FunctionNotFound(message["function_call"]["name"])
+        raise FunctionNotFound(function_name)
 
     async def execute_function(
         self,
         user_input: conversation.ConversationInput,
         messages,
-        message,
+        message: ChatCompletionMessage,
         exposed_entities,
         n_requests,
         function,
@@ -354,9 +351,9 @@ async def execute_function(
         function_executor = get_function_executor(function["function"]["type"])
 
         try:
-            arguments = json.loads(message["function_call"]["arguments"])
+            arguments = json.loads(message.function_call.arguments)
         except json.decoder.JSONDecodeError as err:
-            raise ParseArgumentsFailed(message["function_call"]["arguments"]) from err
+            raise ParseArgumentsFailed(message.function_call.arguments) from err
 
         result = await function_executor.execute(
             self.hass, function["function"], arguments, user_input, exposed_entities
@@ -365,7 +362,7 @@ async def execute_function(
         messages.append(
             {
                 "role": "function",
-                "name": message["function_call"]["name"],
+                "name": message.function_call.name,
                 "content": str(result),
             }
         )
diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py
index 168a06c..4692af8 100644
--- a/custom_components/extended_openai_conversation/config_flow.py
+++ b/custom_components/extended_openai_conversation/config_flow.py
@@ -7,7 +7,7 @@
 from types import MappingProxyType
 from typing import Any
 
-from openai import error
+from openai._exceptions import APIConnectionError, AuthenticationError
 import voluptuous as vol
 
 from homeassistant import config_entries
@@ -26,7 +26,7 @@
     SelectSelectorMode,
 )
 
-from .helpers import validate_authentication, get_default_model_key
+from .helpers import validate_authentication
 
 from .const import (
     CONF_ATTACH_USERNAME,
@@ -40,8 +40,6 @@
     CONF_BASE_URL,
     CONF_API_VERSION,
     CONF_SKIP_AUTHENTICATION,
-    CONF_MODEL_KEY,
-    MODEL_KEYS,
     DEFAULT_ATTACH_USERNAME,
     DEFAULT_CHAT_MODEL,
     DEFAULT_MAX_TOKENS,
@@ -128,9 +126,9 @@ async def async_step_user(
 
         try:
             await validate_input(self.hass, user_input)
-        except error.APIConnectionError:
+        except APIConnectionError:
             errors["base"] = "cannot_connect"
-        except error.AuthenticationError:
+        except AuthenticationError:
             errors["base"] = "invalid_auth"
         except Exception:  # pylint: disable=broad-except
             _LOGGER.exception("Unexpected exception")
@@ -224,18 +222,4 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
                 description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
                 default=DEFAULT_ATTACH_USERNAME,
             ): BooleanSelector(),
-            vol.Optional(
-                CONF_MODEL_KEY,
-                description={"suggested_value": options.get(CONF_MODEL_KEY)},
-                default=get_default_model_key(
-                    self.config_entry.data.get(CONF_BASE_URL)
-                ),
-            ): SelectSelector(
-                SelectSelectorConfig(
-                    options=[
-                        SelectOptionDict(value=key, label=key) for key in MODEL_KEYS
-                    ],
-                    mode=SelectSelectorMode.DROPDOWN,
-                )
-            ),
         }
diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py
index e607769..25768ed 100644
--- a/custom_components/extended_openai_conversation/const.py
+++ b/custom_components/extended_openai_conversation/const.py
@@ -84,5 +84,3 @@
 ]
 CONF_ATTACH_USERNAME = "attach_username"
 DEFAULT_ATTACH_USERNAME = False
-CONF_MODEL_KEY = "model_key"
-MODEL_KEYS = ["model", "engine"]
diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py
index 2a43d9d..6012316 100644
--- a/custom_components/extended_openai_conversation/helpers.py
+++ b/custom_components/extended_openai_conversation/helpers.py
@@ -4,10 +4,9 @@
 import yaml
 import time
 import sqlite3
-import openai
+from openai import AsyncOpenAI, AsyncAzureOpenAI
 import re
 import voluptuous as vol
-from functools import partial
 from bs4 import BeautifulSoup
 from typing import Any
 from urllib import parse
@@ -77,17 +76,10 @@ def get_function_executor(value: str):
     return function_executor
 
 
-def get_api_type(base_url: str):
+def is_azure(base_url: str):
     if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url):
-        return "azure"
-    return None
-
-
-def get_default_model_key(base_url: str):
-    is_azure = get_api_type(base_url) == "azure"
-    if is_azure:
-        return "engine"
-    return "model"
+        return True
+    return False
 
 
 def convert_to_template(
@@ -151,16 +143,12 @@ async def validate_authentication(
     if skip_authentication:
         return
 
-    await hass.async_add_executor_job(
-        partial(
-            openai.Model.list,
-            api_type=get_api_type(base_url),
-            api_key=api_key,
-            api_version=api_version,
-            api_base=base_url,
-            request_timeout=10,
-        )
-    )
+    if is_azure(base_url):
+        client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+    else:
+        client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    await client.models.list(timeout=10)
 
 
 class FunctionExecutor(ABC):
diff --git a/custom_components/extended_openai_conversation/manifest.json b/custom_components/extended_openai_conversation/manifest.json
index f496438..44fadee 100644
--- a/custom_components/extended_openai_conversation/manifest.json
+++ b/custom_components/extended_openai_conversation/manifest.json
@@ -17,7 +17,7 @@
   "iot_class": "cloud_polling",
   "issue_tracker": "https://github.com/jekalmin/extended_openai_conversation/issues",
   "requirements": [
-    "openai==0.27.2"
+    "openai~=1.3.8"
   ],
   "version": "0.0.9"
 }
\ No newline at end of file
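
The hunks above all apply the same openai>=1.0 migration pattern: module-level openai.ChatCompletion.acreate / openai.Model.list calls are replaced by methods on a per-instance AsyncOpenAI or AsyncAzureOpenAI client, and responses become typed objects instead of dicts. A minimal standalone sketch of that pattern follows; the API key, endpoint, api_version, and model below are placeholders rather than values from this component's config entry, and the substring check is a simplified stand-in for the component's is_azure()/AZURE_DOMAIN_PATTERN helper.

import asyncio

from openai import AsyncAzureOpenAI, AsyncOpenAI, OpenAIError


async def main() -> None:
    # Placeholder credentials; in the integration these come from the config
    # entry (CONF_API_KEY, CONF_BASE_URL, CONF_API_VERSION).
    api_key = "sk-..."
    base_url = None  # e.g. "https://example.openai.azure.com" for Azure

    # Same client selection the diff adds to OpenAIAgent.__init__ and
    # validate_authentication; the substring check stands in for is_azure().
    if base_url and ".openai.azure.com" in base_url:
        client = AsyncAzureOpenAI(
            api_key=api_key, azure_endpoint=base_url, api_version="2023-05-15"
        )
    else:
        client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    try:
        # v1 replaces openai.ChatCompletion.acreate(...) with an async method
        # on the client object; the result is a typed ChatCompletion.
        response = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Turn on the kitchen light"}],
            max_tokens=150,
            top_p=1.0,
            temperature=0.5,
        )
    except OpenAIError as err:
        print(f"Request failed: {err}")
        return

    choice = response.choices[0]
    # Attribute access replaces dict access, and model_dump() converts the
    # message back to a plain dict (e.g. for appending to message history).
    if choice.finish_reason == "function_call":
        print("Function requested:", choice.message.function_call.name)
    else:
        print(choice.message.content)
    print(choice.message.model_dump(exclude_none=True))


asyncio.run(main())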