Skip to content

Commit

Permalink
Merge pull request #153 from jekalmin/v1.0.3
Browse files Browse the repository at this point in the history
1.0.3
  • Loading branch information
jekalmin authored Feb 21, 2024
2 parents f54662d + 690125d commit 45be8f9
Show file tree
Hide file tree
Showing 17 changed files with 359 additions and 138 deletions.
109 changes: 67 additions & 42 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,83 @@
"""The OpenAI Conversation integration."""
from __future__ import annotations

import json
import logging
from typing import Literal
import json
import yaml

from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import AuthenticationError, OpenAIError
from openai.types.chat.chat_completion import (
Choice,
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai._exceptions import OpenAIError, AuthenticationError
import yaml

from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL, ATTR_NAME
from homeassistant.const import ATTR_NAME, CONF_API_KEY, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.exceptions import (
ConfigEntryNotReady,
HomeAssistantError,
TemplateError,
)

from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
intent,
template,
entity_registry as er,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid

from .const import (
CONF_API_VERSION,
CONF_ATTACH_USERNAME,
CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONF_FUNCTIONS,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_MAX_TOKENS,
CONF_ORGANIZATION,
CONF_PROMPT,
CONF_SKIP_AUTHENTICATION,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DOMAIN,
EVENT_CONVERSATION_FINISHED,
)

from .exceptions import (
FunctionNotFound,
FunctionLoadFailed,
ParseArgumentsFailed,
FunctionNotFound,
InvalidFunction,
ParseArgumentsFailed,
TokenLengthExceededError,
)

from .helpers import (
validate_authentication,
get_function_executor,
is_azure,
validate_authentication,
)

from .services import async_setup_services


_LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
Expand All @@ -104,6 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
api_key=entry.data[CONF_API_KEY],
base_url=entry.data.get(CONF_BASE_URL),
api_version=entry.data.get(CONF_API_VERSION),
organization=entry.data.get(CONF_ORGANIZATION),
skip_authentication=entry.data.get(
CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION
),
Expand Down Expand Up @@ -145,10 +144,13 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
api_key=entry.data[CONF_API_KEY],
azure_endpoint=base_url,
api_version=entry.data.get(CONF_API_VERSION),
organization=entry.data.get(CONF_ORGANIZATION),
)
else:
self.client = AsyncOpenAI(
api_key=entry.data[CONF_API_KEY], base_url=base_url
api_key=entry.data[CONF_API_KEY],
base_url=base_url,
organization=entry.data.get(CONF_ORGANIZATION),
)

@property
Expand Down Expand Up @@ -191,7 +193,7 @@ async def async_process(
messages.append(user_message)

try:
response = await self.query(user_input, messages, exposed_entities, 0)
query_response = await self.query(user_input, messages, exposed_entities, 0)
except OpenAIError as err:
_LOGGER.error(err)
intent_response = intent.IntentResponse(language=user_input.language)
Expand All @@ -213,11 +215,20 @@ async def async_process(
response=intent_response, conversation_id=conversation_id
)

messages.append(response.model_dump(exclude_none=True))
messages.append(query_response.message.model_dump(exclude_none=True))
self.history[conversation_id] = messages

self.hass.bus.async_fire(
EVENT_CONVERSATION_FINISHED,
{
"response": query_response.response.model_dump(),
"user_input": user_input,
"messages": messages,
},
)

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content)
intent_response.async_set_speech(query_response.message.content)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
Expand Down Expand Up @@ -317,7 +328,7 @@ async def query(
messages,
exposed_entities,
n_requests,
):
) -> OpenAIQueryResponse:
"""Process a sentence."""
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
Expand Down Expand Up @@ -366,14 +377,17 @@ async def query(
message = choice.message

if choice.finish_reason == "function_call":
message = await self.execute_function_call(
return await self.execute_function_call(
user_input, messages, message, exposed_entities, n_requests + 1
)
if choice.finish_reason == "tool_calls":
message = await self.execute_tool_calls(
return await self.execute_tool_calls(
user_input, messages, message, exposed_entities, n_requests + 1
)
return message
if choice.finish_reason == "length":
raise TokenLengthExceededError(response.usage.completion_tokens)

return OpenAIQueryResponse(response=response, message=message)

async def execute_function_call(
self,
Expand All @@ -382,7 +396,7 @@ async def execute_function_call(
message: ChatCompletionMessage,
exposed_entities,
n_requests,
):
) -> OpenAIQueryResponse:
function_name = message.function_call.name
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
Expand All @@ -407,7 +421,7 @@ async def execute_function(
exposed_entities,
n_requests,
function,
):
) -> OpenAIQueryResponse:
function_executor = get_function_executor(function["function"]["type"])

try:
Expand Down Expand Up @@ -435,7 +449,7 @@ async def execute_tool_calls(
message: ChatCompletionMessage,
exposed_entities,
n_requests,
):
) -> OpenAIQueryResponse:
messages.append(message.model_dump(exclude_none=True))
for tool in message.tool_calls:
function_name = tool.function.name
Expand Down Expand Up @@ -469,7 +483,7 @@ async def execute_tool_function(
tool,
exposed_entities,
function,
):
) -> OpenAIQueryResponse:
function_executor = get_function_executor(function["function"]["type"])

try:
Expand All @@ -481,3 +495,14 @@ async def execute_tool_function(
self.hass, function["function"], arguments, user_input, exposed_entities
)
return result


class OpenAIQueryResponse:
"""OpenAI query response value object."""

def __init__(
self, response: ChatCompletion, message: ChatCompletionMessage
) -> None:
"""Initialize OpenAI query response value object."""
self.response = response
self.message = message
45 changes: 24 additions & 21 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,62 @@

import logging
import types
import yaml
from types import MappingProxyType
from typing import Any

from openai._exceptions import APIConnectionError, AuthenticationError
import voluptuous as vol
import yaml

from homeassistant import config_entries
from homeassistant.const import CONF_NAME, CONF_API_KEY
from homeassistant.const import CONF_API_KEY, CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
BooleanSelector,
NumberSelector,
NumberSelectorConfig,
TemplateSelector,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectOptionDict,
SelectSelectorMode,
TemplateSelector,
)

from .helpers import validate_authentication

from .const import (
CONF_API_VERSION,
CONF_ATTACH_USERNAME,
CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONF_FUNCTIONS,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_MAX_TOKENS,
CONF_ORGANIZATION,
CONF_PROMPT,
CONF_SKIP_AUTHENTICATION,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
CONF_FUNCTIONS,
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
CONF_CONTEXT_THRESHOLD,
CONF_CONTEXT_TRUNCATE_STRATEGY,
CONTEXT_TRUNCATE_STRATEGIES,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_CONF_BASE_URL,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_TOKENS,
DEFAULT_NAME,
DEFAULT_PROMPT,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONF_BASE_URL,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DEFAULT_CONTEXT_THRESHOLD,
DEFAULT_CONTEXT_TRUNCATE_STRATEGY,
CONTEXT_TRUNCATE_STRATEGIES,
DOMAIN,
DEFAULT_NAME,
)
from .helpers import validate_authentication

_LOGGER = logging.getLogger(__name__)

Expand All @@ -68,6 +68,7 @@
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str,
vol.Optional(CONF_API_VERSION): str,
vol.Optional(CONF_ORGANIZATION): str,
vol.Optional(
CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION
): bool,
Expand Down Expand Up @@ -101,6 +102,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
api_key = data[CONF_API_KEY]
base_url = data.get(CONF_BASE_URL)
api_version = data.get(CONF_API_VERSION)
organization = data.get(CONF_ORGANIZATION)
skip_authentication = data.get(CONF_SKIP_AUTHENTICATION)

if base_url == DEFAULT_CONF_BASE_URL:
Expand All @@ -113,6 +115,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
api_key=api_key,
base_url=base_url,
api_version=api_version,
organization=organization,
skip_authentication=skip_authentication,
)

Expand Down
4 changes: 4 additions & 0 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

DOMAIN = "extended_openai_conversation"
DEFAULT_NAME = "Extended OpenAI Conversation"
CONF_ORGANIZATION = "organization"
CONF_BASE_URL = "base_url"
DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1"
CONF_API_VERSION = "api_version"
CONF_SKIP_AUTHENTICATION = "skip_authentication"
DEFAULT_SKIP_AUTHENTICATION = False

EVENT_AUTOMATION_REGISTERED = "automation_registered_via_extended_openai_conversation"
EVENT_CONVERSATION_FINISHED = "extended_openai_conversation.conversation.finished"

CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """I want you to act as smart home manager of Home Assistant.
Expand Down Expand Up @@ -93,3 +95,5 @@
DEFAULT_CONTEXT_TRUNCATE_STRATEGY = CONTEXT_TRUNCATE_STRATEGIES[0]["key"]

SERVICE_QUERY_IMAGE = "query_image"

CONF_PAYLOAD_TEMPLATE = "payload_template"
Loading

0 comments on commit 45be8f9

Please sign in to comment.