Skip to content

Commit

Permalink
Merge pull request #76 from jekalmin/v1.0.0
Browse files Browse the repository at this point in the history
1.0.0
  • Loading branch information
jekalmin authored Jan 4, 2024
2 parents d7ba5f8 + 5e9132b commit d85230b
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 88 deletions.
65 changes: 31 additions & 34 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
import json
import yaml

import openai
from openai import error
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import (
Choice,
ChatCompletion,
ChatCompletionMessage,
)
from openai._exceptions import OpenAIError, AuthenticationError

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
Expand Down Expand Up @@ -41,7 +46,6 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand Down Expand Up @@ -77,8 +81,7 @@
convert_to_template,
validate_authentication,
get_function_executor,
get_api_type,
get_default_model_key,
is_azure,
)


Expand All @@ -105,10 +108,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
),
)
except error.AuthenticationError as err:
except AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err)
return False
except error.OpenAIError as err:
except OpenAIError as err:
raise ConfigEntryNotReady(err) from err

agent = OpenAIAgent(hass, entry)
Expand Down Expand Up @@ -136,6 +139,11 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
self.hass = hass
self.entry = entry
self.history: dict[str, list[dict]] = {}
base_url = entry.data.get(CONF_BASE_URL)
if is_azure(base_url):
self.client = AsyncAzureOpenAI(api_key=entry.data[CONF_API_KEY], azure_endpoint=base_url, api_version=entry.data.get(CONF_API_VERSION))
else:
self.client = AsyncOpenAI(api_key=entry.data[CONF_API_KEY], base_url=base_url)

@property
def supported_languages(self) -> list[str] | Literal["*"]:
Expand Down Expand Up @@ -177,7 +185,7 @@ async def async_process(

try:
response = await self.query(user_input, messages, exposed_entities, 0)
except error.OpenAIError as err:
except OpenAIError as err:
_LOGGER.error(err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
Expand All @@ -198,11 +206,11 @@ async def async_process(
response=intent_response, conversation_id=conversation_id
)

messages.append(response)
messages.append(response.model_dump(exclude_none=True))
self.history[conversation_id] = messages

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["content"])
intent_response.async_set_speech(response.content)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
Expand Down Expand Up @@ -269,10 +277,6 @@ async def query(
n_requests,
):
"""Process a sentence."""
api_base = self.entry.data.get(CONF_BASE_URL)
api_key = self.entry.data[CONF_API_KEY]
api_type = get_api_type(api_base)
api_version = self.entry.data.get(CONF_API_VERSION)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
Expand All @@ -288,31 +292,24 @@ async def query(
functions = None
function_call = None

model_key = self.entry.options.get(
CONF_MODEL_KEY, get_default_model_key(api_base)
)
model_kwargs = {model_key: model}

_LOGGER.info("Prompt for %s: %s", model, messages)

response = await openai.ChatCompletion.acreate(
api_base=api_base,
api_key=api_key,
api_type=api_type,
api_version=api_version,
response: ChatCompletion = await self.client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=user_input.conversation_id,
functions=functions,
function_call=function_call,
**model_kwargs,
)


_LOGGER.info("Response %s", response)
message = response["choices"][0]["message"]
if message.get("function_call"):
choice: Choice = response.choices[0]
message = choice.message
if choice.finish_reason == "function_call":
message = await self.execute_function_call(
user_input, messages, message, exposed_entities, n_requests + 1
)
Expand All @@ -322,11 +319,11 @@ def execute_function_call(
self,
user_input: conversation.ConversationInput,
messages,
message,
message: ChatCompletionMessage,
exposed_entities,
n_requests,
):
function_name = message["function_call"]["name"]
function_name = message.function_call.name
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
None,
Expand All @@ -340,23 +337,23 @@ def execute_function_call(
n_requests,
function,
)
raise FunctionNotFound(message["function_call"]["name"])
raise FunctionNotFound(function_name)

async def execute_function(
self,
user_input: conversation.ConversationInput,
messages,
message,
message: ChatCompletionMessage,
exposed_entities,
n_requests,
function,
):
function_executor = get_function_executor(function["function"]["type"])

try:
arguments = json.loads(message["function_call"]["arguments"])
arguments = json.loads(message.function_call.arguments)
except json.decoder.JSONDecodeError as err:
raise ParseArgumentsFailed(message["function_call"]["arguments"]) from err
raise ParseArgumentsFailed(message.function_call.arguments) from err

result = await function_executor.execute(
self.hass, function["function"], arguments, user_input, exposed_entities
Expand All @@ -365,7 +362,7 @@ async def execute_function(
messages.append(
{
"role": "function",
"name": message["function_call"]["name"],
"name": message.function_call.name,
"content": str(result),
}
)
Expand Down
24 changes: 4 additions & 20 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import MappingProxyType
from typing import Any

from openai import error
from openai._exceptions import APIConnectionError, AuthenticationError
import voluptuous as vol

from homeassistant import config_entries
Expand All @@ -26,7 +26,7 @@
SelectSelectorMode,
)

from .helpers import validate_authentication, get_default_model_key
from .helpers import validate_authentication

from .const import (
CONF_ATTACH_USERNAME,
Expand All @@ -40,8 +40,6 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_MODEL_KEY,
MODEL_KEYS,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand Down Expand Up @@ -128,9 +126,9 @@ async def async_step_user(

try:
await validate_input(self.hass, user_input)
except error.APIConnectionError:
except APIConnectionError:
errors["base"] = "cannot_connect"
except error.AuthenticationError:
except AuthenticationError:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
Expand Down Expand Up @@ -224,18 +222,4 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
vol.Optional(
CONF_MODEL_KEY,
description={"suggested_value": options.get(CONF_MODEL_KEY)},
default=get_default_model_key(
self.config_entry.data.get(CONF_BASE_URL)
),
): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(value=key, label=key) for key in MODEL_KEYS
],
mode=SelectSelectorMode.DROPDOWN,
)
),
}
4 changes: 1 addition & 3 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
Do not restate or appreciate what user says, rather make a quick inquiry.
"""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
DEFAULT_CHAT_MODEL = "gpt-3.5-turbo-1106"
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
CONF_TOP_P = "top_p"
Expand Down Expand Up @@ -84,5 +84,3 @@
]
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False
CONF_MODEL_KEY = "model_key"
MODEL_KEYS = ["model", "engine"]
32 changes: 10 additions & 22 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import yaml
import time
import sqlite3
import openai
from openai import AsyncOpenAI, AsyncAzureOpenAI
import re
import voluptuous as vol
from functools import partial
from bs4 import BeautifulSoup
from typing import Any
from urllib import parse
Expand Down Expand Up @@ -77,17 +76,10 @@ def get_function_executor(value: str):
return function_executor


def get_api_type(base_url: str):
def is_azure(base_url: str):
if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url):
return "azure"
return None


def get_default_model_key(base_url: str):
is_azure = get_api_type(base_url) == "azure"
if is_azure:
return "engine"
return "model"
return True
return False


def convert_to_template(
Expand Down Expand Up @@ -151,16 +143,12 @@ async def validate_authentication(
if skip_authentication:
return

await hass.async_add_executor_job(
partial(
openai.Model.list,
api_type=get_api_type(base_url),
api_key=api_key,
api_version=api_version,
api_base=base_url,
request_timeout=10,
)
)
if is_azure(base_url):
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
else:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

await client.models.list(timeout=10)


class FunctionExecutor(ABC):
Expand Down
4 changes: 2 additions & 2 deletions custom_components/extended_openai_conversation/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/jekalmin/extended_openai_conversation/issues",
"requirements": [
"openai==0.27.2"
"openai~=1.3.8"
],
"version": "0.0.9"
"version": "1.0.0"
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"user": {
"data": {
"name": "Naam",
"api_key": "API-sleutel",
"base_url": "Basis-URL",
"api_version": "Api Version",
"skip_authentication": "Skip Authentication"
"api_key": "API Sleutel",
"base_url": "Basis URL",
"api_version": "API Version",
"skip_authentication": "Authenticatie overslaan"
}
}
}
Expand All @@ -28,8 +28,8 @@
"top_p": "Top P",
"max_function_calls_per_conversation": "Maximale keren functies mogen worden aangeroepen per conversatie",
"functions": "Functies",
"attach_username": "Attach Username to Message",
"model_key": "Model Key"
"attach_username": "Gebruikersnaam aan bericht toevoegen",
"model_key": "Model Sleutel"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"config": {
"error": {
"cannot_connect": "Nie udało się połączyć",
"invalid_auth": "Błąd authentykacji",
"unknown": "Nieznany błąd"
},
"step": {
"user": {
"data": {
"name": "Imię",
"api_key": "Klucz API",
"base_url": "Bazowy URL",
"api_version": "Wersja API",
"skip_authentication": "Pomiń authentykację"
}
}
}
},
"options": {
"step": {
"init": {
"data": {
"max_tokens": "Maksymalna ilość tokenów w odpowiedzi",
"model": "Model",
"prompt": "Prompt",
"temperature": "Temperatura",
"top_p": "Top P",
"max_function_calls_per_conversation": "Maksymalna ilość wywołań funkcji na rozmowę",
"functions": "Funkcje",
"attach_username": "Dodaj nazwę użytkownika do wiadomości",
"model_key": "Klucz modelu"
}
}
}
}
}
3 changes: 2 additions & 1 deletion hacs.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"name": "extended_openai_conversation",
"render_readme": true
"render_readme": true,
"homeassistant": "2024.1.0b0"
}

0 comments on commit d85230b

Please sign in to comment.