Skip to content

Commit

Permalink
bump openai version from "0.27.2" to "1.3.8"
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Jan 3, 2024
1 parent b880613 commit dad7dfd
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 79 deletions.
65 changes: 31 additions & 34 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
import json
import yaml

import openai
from openai import error
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import (
Choice,
ChatCompletion,
ChatCompletionMessage,
)
from openai._exceptions import OpenAIError, AuthenticationError

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -41,7 +46,6 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand Down Expand Up @@ -77,8 +81,7 @@
convert_to_template,
validate_authentication,
get_function_executor,
get_api_type,
get_default_model_key,
is_azure,
)


Expand All @@ -105,10 +108,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
),
)
except error.AuthenticationError as err:
except AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err)
return False
except error.OpenAIError as err:
except OpenAIError as err:
raise ConfigEntryNotReady(err) from err

agent = OpenAIAgent(hass, entry)
Expand Down Expand Up @@ -136,6 +139,11 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
self.hass = hass
self.entry = entry
self.history: dict[str, list[dict]] = {}
base_url = entry.data.get(CONF_BASE_URL)
if is_azure(base_url):
self.client = AsyncAzureOpenAI(api_key=entry.data[CONF_API_KEY], azure_endpoint=base_url, api_version=entry.data.get(CONF_API_VERSION))
else:
self.client = AsyncOpenAI(api_key=entry.data[CONF_API_KEY], base_url=base_url)

@property
def supported_languages(self) -> list[str] | Literal["*"]:
Expand Down Expand Up @@ -177,7 +185,7 @@ async def async_process(

try:
response = await self.query(user_input, messages, exposed_entities, 0)
except error.OpenAIError as err:
except OpenAIError as err:
_LOGGER.error(err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
Expand All @@ -198,11 +206,11 @@ async def async_process(
response=intent_response, conversation_id=conversation_id
)

messages.append(response)
messages.append(response.model_dump(exclude_none=True))
self.history[conversation_id] = messages

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["content"])
intent_response.async_set_speech(response.content)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
Expand Down Expand Up @@ -269,10 +277,6 @@ async def query(
n_requests,
):
"""Process a sentence."""
api_base = self.entry.data.get(CONF_BASE_URL)
api_key = self.entry.data[CONF_API_KEY]
api_type = get_api_type(api_base)
api_version = self.entry.data.get(CONF_API_VERSION)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
Expand All @@ -288,31 +292,24 @@ async def query(
functions = None
function_call = None

model_key = self.entry.options.get(
CONF_MODEL_KEY, get_default_model_key(api_base)
)
model_kwargs = {model_key: model}

_LOGGER.info("Prompt for %s: %s", model, messages)

response = await openai.ChatCompletion.acreate(
api_base=api_base,
api_key=api_key,
api_type=api_type,
api_version=api_version,
response: ChatCompletion = await self.client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=user_input.conversation_id,
functions=functions,
function_call=function_call,
**model_kwargs,
)


_LOGGER.info("Response %s", response)
message = response["choices"][0]["message"]
if message.get("function_call"):
choice: Choice = response.choices[0]
message = choice.message
if choice.finish_reason == "function_call":
message = await self.execute_function_call(
user_input, messages, message, exposed_entities, n_requests + 1
)
Expand All @@ -322,11 +319,11 @@ def execute_function_call(
self,
user_input: conversation.ConversationInput,
messages,
message,
message: ChatCompletionMessage,
exposed_entities,
n_requests,
):
function_name = message["function_call"]["name"]
function_name = message.function_call.name
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
None,
Expand All @@ -340,23 +337,23 @@ def execute_function_call(
n_requests,
function,
)
raise FunctionNotFound(message["function_call"]["name"])
raise FunctionNotFound(function_name)

async def execute_function(
self,
user_input: conversation.ConversationInput,
messages,
message,
message: ChatCompletionMessage,
exposed_entities,
n_requests,
function,
):
function_executor = get_function_executor(function["function"]["type"])

try:
arguments = json.loads(message["function_call"]["arguments"])
arguments = json.loads(message.function_call.arguments)
except json.decoder.JSONDecodeError as err:
raise ParseArgumentsFailed(message["function_call"]["arguments"]) from err
raise ParseArgumentsFailed(message.function_call.arguments) from err

result = await function_executor.execute(
self.hass, function["function"], arguments, user_input, exposed_entities
Expand All @@ -365,7 +362,7 @@ async def execute_function(
messages.append(
{
"role": "function",
"name": message["function_call"]["name"],
"name": message.function_call.name,
"content": str(result),
}
)
Expand Down
24 changes: 4 additions & 20 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import MappingProxyType
from typing import Any

from openai import error
from openai._exceptions import APIConnectionError, AuthenticationError
import voluptuous as vol

from homeassistant import config_entries
Expand All @@ -26,7 +26,7 @@
SelectSelectorMode,
)

from .helpers import validate_authentication, get_default_model_key
from .helpers import validate_authentication

from .const import (
CONF_ATTACH_USERNAME,
Expand All @@ -40,8 +40,6 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
MODEL_KEYS,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand Down Expand Up @@ -128,9 +126,9 @@ async def async_step_user(

try:
await validate_input(self.hass, user_input)
except error.APIConnectionError:
except APIConnectionError:
errors["base"] = "cannot_connect"
except error.AuthenticationError:
except AuthenticationError:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
Expand Down Expand Up @@ -224,18 +222,4 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
vol.Optional(
CONF_MODEL_KEY,
description={"suggested_value": options.get(CONF_MODEL_KEY)},
default=get_default_model_key(
self.config_entry.data.get(CONF_BASE_URL)
),
): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(value=key, label=key) for key in MODEL_KEYS
],
mode=SelectSelectorMode.DROPDOWN,
)
),
}
2 changes: 0 additions & 2 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,3 @@
]
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False
CONF_MODEL_KEY = "model_key"
MODEL_KEYS = ["model", "engine"]
32 changes: 10 additions & 22 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import yaml
import time
import sqlite3
import openai
from openai import AsyncOpenAI, AsyncAzureOpenAI
import re
import voluptuous as vol
from functools import partial
from bs4 import BeautifulSoup
from typing import Any
from urllib import parse
Expand Down Expand Up @@ -77,17 +76,10 @@ def get_function_executor(value: str):
return function_executor


def get_api_type(base_url: str):
def is_azure(base_url: str):
if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url):
return "azure"
return None


def get_default_model_key(base_url: str):
is_azure = get_api_type(base_url) == "azure"
if is_azure:
return "engine"
return "model"
return True
return False


def convert_to_template(
Expand Down Expand Up @@ -151,16 +143,12 @@ async def validate_authentication(
if skip_authentication:
return

await hass.async_add_executor_job(
partial(
openai.Model.list,
api_type=get_api_type(base_url),
api_key=api_key,
api_version=api_version,
api_base=base_url,
request_timeout=10,
)
)
if is_azure(base_url):
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
else:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

await client.models.list(timeout=10)


class FunctionExecutor(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/jekalmin/extended_openai_conversation/issues",
"requirements": [
"openai==0.27.2"
"openai~=1.3.8"
],
"version": "0.0.9"
}

0 comments on commit dad7dfd

Please sign in to comment.