From 89738b0a347f45257f313466accb07e921ca8bac Mon Sep 17 00:00:00 2001 From: jekalmin Date: Tue, 19 Dec 2023 22:09:56 +0900 Subject: [PATCH 1/6] add hacs validation --- .github/workflows/hassfest.yaml | 14 ++++++++++++++ .github/workflows/validate.yaml | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 .github/workflows/hassfest.yaml create mode 100644 .github/workflows/validate.yaml diff --git a/.github/workflows/hassfest.yaml b/.github/workflows/hassfest.yaml new file mode 100644 index 0000000..7df6b77 --- /dev/null +++ b/.github/workflows/hassfest.yaml @@ -0,0 +1,14 @@ +name: Validate with hassfest + +on: + push: + pull_request: + schedule: + - cron: "0 0 * * *" + +jobs: + validate: + runs-on: "ubuntu-latest" + steps: + - uses: "actions/checkout@v4" + - uses: "home-assistant/actions/hassfest@master" \ No newline at end of file diff --git a/.github/workflows/validate.yaml b/.github/workflows/validate.yaml new file mode 100644 index 0000000..6632c1d --- /dev/null +++ b/.github/workflows/validate.yaml @@ -0,0 +1,18 @@ +name: Validate + +on: + push: + pull_request: + schedule: + - cron: "0 0 * * *" + workflow_dispatch: + +jobs: + validate-hacs: + runs-on: "ubuntu-latest" + steps: + - uses: "actions/checkout@v3" + - name: HACS validation + uses: "hacs/action@main" + with: + category: "integration" \ No newline at end of file From 903d83c9355eec0c75d2356e9405b08a71d4046b Mon Sep 17 00:00:00 2001 From: jekalmin Date: Tue, 19 Dec 2023 22:18:23 +0900 Subject: [PATCH 2/6] add "history", "recorder", "rest", "scrape" to dependencies, add issue_tracker --- .../extended_openai_conversation/manifest.json | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/custom_components/extended_openai_conversation/manifest.json b/custom_components/extended_openai_conversation/manifest.json index b8f5b97..f496438 100644 --- a/custom_components/extended_openai_conversation/manifest.json +++ b/custom_components/extended_openai_conversation/manifest.json @@ -6,11 +6,16 @@ ], "config_flow": true, "dependencies": [ - "conversation" + "conversation", + "history", + "recorder", + "rest", + "scrape" ], "documentation": "https://github.com/jekalmin/extended_openai_conversation", "integration_type": "service", "iot_class": "cloud_polling", + "issue_tracker": "https://github.com/jekalmin/extended_openai_conversation/issues", "requirements": [ "openai==0.27.2" ], From 1aba718f6340c7537140fc403c632185b1a438fa Mon Sep 17 00:00:00 2001 From: jekalmin Date: Wed, 13 Dec 2023 00:53:58 +0900 Subject: [PATCH 3/6] [#42] add azure openai (test needed) --- .../extended_openai_conversation/__init__.py | 16 ++++++++++++++-- .../extended_openai_conversation/config_flow.py | 7 ++++++- .../extended_openai_conversation/const.py | 1 + .../extended_openai_conversation/helpers.py | 9 +++++++-- .../extended_openai_conversation/strings.json | 3 ++- .../translations/de.json | 3 ++- .../translations/en.json | 3 ++- .../translations/ko.json | 3 ++- .../translations/nl.json | 3 ++- 9 files changed, 38 insertions(+), 10 deletions(-) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index 7fbcf57..d39d42b 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -1,6 +1,7 @@ """The OpenAI Conversation integration.""" from __future__ import annotations +import re import logging from typing import Literal import json @@ -39,6 +40,7 @@ CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, CONF_FUNCTIONS, CONF_BASE_URL, + CONF_API_VERSION, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -79,6 +81,7 @@ _LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) +AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com" # hass.data key for agent. @@ -93,6 +96,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass=hass, api_key=entry.data[CONF_API_KEY], base_url=entry.data.get(CONF_BASE_URL), + api_version=entry.data.get(CONF_API_VERSION), ) except error.AuthenticationError as err: _LOGGER.error("Invalid API key: %s", err) @@ -258,6 +262,12 @@ async def query( n_requests, ): """Process a sentence.""" + api_base = self.entry.data.get(CONF_BASE_URL) + api_key = self.entry.data[CONF_API_KEY] + api_type = None + if api_base and re.search(AZURE_DOMAIN_PATTERN, api_base): + api_type = "azure" + api_version = self.entry.data.get(CONF_API_VERSION) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) @@ -274,8 +284,10 @@ async def query( _LOGGER.info("Prompt for %s: %s", model, messages) response = await openai.ChatCompletion.acreate( - api_base=self.entry.data.get(CONF_BASE_URL), - api_key=self.entry.data[CONF_API_KEY], + api_base=api_base, + api_key=api_key, + api_type=api_type, + api_version=api_version, model=model, messages=messages, max_tokens=max_tokens, diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index 94e17d3..a632e3e 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -34,6 +34,7 @@ CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, CONF_FUNCTIONS, CONF_BASE_URL, + CONF_API_VERSION, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -54,6 +55,7 @@ vol.Optional(CONF_NAME): str, vol.Required(CONF_API_KEY): str, vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str, + vol.Optional(CONF_API_VERSION): str, } ) @@ -80,13 +82,16 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """ api_key = data[CONF_API_KEY] base_url = data.get(CONF_BASE_URL) + api_version = data.get(CONF_API_VERSION) if base_url == DEFAULT_CONF_BASE_URL: # Do not set base_url if using OpenAI for case of OpenAI's base_url change base_url = None data.pop(CONF_BASE_URL) - await validate_authentication(hass=hass, api_key=api_key, base_url=base_url) + await validate_authentication( + hass=hass, api_key=api_key, base_url=base_url, api_version=api_version + ) class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py index 3203189..778d8fc 100644 --- a/custom_components/extended_openai_conversation/const.py +++ b/custom_components/extended_openai_conversation/const.py @@ -77,5 +77,6 @@ ] CONF_BASE_URL = "base_url" DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1" +CONF_API_VERSION = "api_version" CONF_ATTACH_USERNAME = "attach_username" DEFAULT_ATTACH_USERNAME = False diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index 9da483d..86f6e58 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -122,13 +122,18 @@ def _get_rest_data(hass, rest_config, arguments): async def validate_authentication( - hass: HomeAssistant, api_key: str, base_url: str + hass: HomeAssistant, api_key: str, base_url: str, api_version: str ) -> None: if not base_url: base_url = DEFAULT_CONF_BASE_URL + + url = f"{base_url}/models" + if api_version: + url = f"{url}?api-version={api_version}" + session = async_get_clientsession(hass) response = await session.get( - f"{base_url}/models", + url, headers={"Authorization": f"Bearer {api_key}"}, timeout=10, ) diff --git a/custom_components/extended_openai_conversation/strings.json b/custom_components/extended_openai_conversation/strings.json index 4db21d6..06b5512 100644 --- a/custom_components/extended_openai_conversation/strings.json +++ b/custom_components/extended_openai_conversation/strings.json @@ -5,7 +5,8 @@ "data": { "name": "[%key:common::config_flow::data::name%]", "api_key": "[%key:common::config_flow::data::api_key%]", - "base_url": "[%key:common::config_flow::data::base_url%]" + "base_url": "[%key:common::config_flow::data::base_url%]", + "api_version": "[%key:common::config_flow::data::api_version%]" } } }, diff --git a/custom_components/extended_openai_conversation/translations/de.json b/custom_components/extended_openai_conversation/translations/de.json index a6b2346..5aa8899 100644 --- a/custom_components/extended_openai_conversation/translations/de.json +++ b/custom_components/extended_openai_conversation/translations/de.json @@ -10,7 +10,8 @@ "data": { "name": "Name", "api_key": "API Key", - "base_url": "Base Url" + "base_url": "Base Url", + "api_version": "Api Version" } } } diff --git a/custom_components/extended_openai_conversation/translations/en.json b/custom_components/extended_openai_conversation/translations/en.json index 7f0a07e..61644a3 100644 --- a/custom_components/extended_openai_conversation/translations/en.json +++ b/custom_components/extended_openai_conversation/translations/en.json @@ -10,7 +10,8 @@ "data": { "name": "Name", "api_key": "API Key", - "base_url": "Base Url" + "base_url": "Base Url", + "api_version": "Api Version" } } } diff --git a/custom_components/extended_openai_conversation/translations/ko.json b/custom_components/extended_openai_conversation/translations/ko.json index 7f0a07e..61644a3 100644 --- a/custom_components/extended_openai_conversation/translations/ko.json +++ b/custom_components/extended_openai_conversation/translations/ko.json @@ -10,7 +10,8 @@ "data": { "name": "Name", "api_key": "API Key", - "base_url": "Base Url" + "base_url": "Base Url", + "api_version": "Api Version" } } } diff --git a/custom_components/extended_openai_conversation/translations/nl.json b/custom_components/extended_openai_conversation/translations/nl.json index 77e4139..5e0c42a 100644 --- a/custom_components/extended_openai_conversation/translations/nl.json +++ b/custom_components/extended_openai_conversation/translations/nl.json @@ -10,7 +10,8 @@ "data": { "name": "Naam", "api_key": "API-sleutel", - "base_url": "Basis-URL" + "base_url": "Basis-URL", + "api_version": "Api Version" } } } From bdc9c2ca8c2432b1f2f8562ed281d13ebf9528e5 Mon Sep 17 00:00:00 2001 From: jekalmin Date: Sun, 24 Dec 2023 17:08:25 +0900 Subject: [PATCH 4/6] add "skip_authentication" and "model_key" options --- .../extended_openai_conversation/__init__.py | 20 ++- .../config_flow.py | 145 +++++++++++------- .../extended_openai_conversation/const.py | 13 +- .../extended_openai_conversation/helpers.py | 48 +++--- .../extended_openai_conversation/strings.json | 6 +- .../translations/de.json | 7 +- .../translations/en.json | 6 +- .../translations/ko.json | 6 +- .../translations/nl.json | 7 +- 9 files changed, 162 insertions(+), 96 deletions(-) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index d39d42b..53d89bc 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -1,7 +1,6 @@ """The OpenAI Conversation integration.""" from __future__ import annotations -import re import logging from typing import Literal import json @@ -41,6 +40,8 @@ CONF_FUNCTIONS, CONF_BASE_URL, CONF_API_VERSION, + CONF_SKIP_AUTHENTICATION, + CONF_MODEL_KEY, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -49,6 +50,8 @@ DEFAULT_TOP_P, DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, + DEFAULT_SKIP_AUTHENTICATION, + DEFAULT_MODEL_KEY, DOMAIN, ) @@ -75,6 +78,7 @@ convert_to_template, validate_authentication, get_function_executor, + get_api_type, ) @@ -97,6 +101,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: api_key=entry.data[CONF_API_KEY], base_url=entry.data.get(CONF_BASE_URL), api_version=entry.data.get(CONF_API_VERSION), + skip_authentication=entry.data.get( + CONF_SKIP_AUTHENTICATION, DEFAULT_SKIP_AUTHENTICATION + ), ) except error.AuthenticationError as err: _LOGGER.error("Invalid API key: %s", err) @@ -264,9 +271,7 @@ async def query( """Process a sentence.""" api_base = self.entry.data.get(CONF_BASE_URL) api_key = self.entry.data[CONF_API_KEY] - api_type = None - if api_base and re.search(AZURE_DOMAIN_PATTERN, api_base): - api_type = "azure" + api_type = get_api_type(api_base) api_version = self.entry.data.get(CONF_API_VERSION) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) @@ -279,7 +284,9 @@ async def query( DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, ): function_call = "none" - response_format = {"type": "text"} + model_kwargs = { + self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model + } _LOGGER.info("Prompt for %s: %s", model, messages) @@ -288,7 +295,6 @@ async def query( api_key=api_key, api_type=api_type, api_version=api_version, - model=model, messages=messages, max_tokens=max_tokens, top_p=top_p, @@ -296,7 +302,7 @@ async def query( user=user_input.conversation_id, functions=functions, function_call=function_call, - response_format=response_format, + **model_kwargs, ) _LOGGER.info("Response %s", response) diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index a632e3e..3e2184d 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -20,9 +20,13 @@ NumberSelectorConfig, TemplateSelector, AttributeSelector, + SelectSelector, + SelectSelectorConfig, + SelectOptionDict, + SelectSelectorMode, ) -from .helpers import validate_authentication +from .helpers import validate_authentication, get_api_type from .const import ( CONF_ATTACH_USERNAME, @@ -35,6 +39,9 @@ CONF_FUNCTIONS, CONF_BASE_URL, CONF_API_VERSION, + CONF_SKIP_AUTHENTICATION, + CONF_MODEL_KEY, + MODEL_KEYS, DEFAULT_ATTACH_USERNAME, DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, @@ -44,6 +51,8 @@ DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, DEFAULT_CONF_BASE_URL, + DEFAULT_SKIP_AUTHENTICATION, + DEFAULT_MODEL_KEY, DOMAIN, DEFAULT_NAME, ) @@ -56,6 +65,9 @@ vol.Required(CONF_API_KEY): str, vol.Optional(CONF_BASE_URL, default=DEFAULT_CONF_BASE_URL): str, vol.Optional(CONF_API_VERSION): str, + vol.Optional( + CONF_SKIP_AUTHENTICATION, default=DEFAULT_SKIP_AUTHENTICATION + ): bool, } ) @@ -83,6 +95,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: api_key = data[CONF_API_KEY] base_url = data.get(CONF_BASE_URL) api_version = data.get(CONF_API_VERSION) + skip_authentication = data.get(CONF_SKIP_AUTHENTICATION) if base_url == DEFAULT_CONF_BASE_URL: # Do not set base_url if using OpenAI for case of OpenAI's base_url change @@ -90,7 +103,11 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: data.pop(CONF_BASE_URL) await validate_authentication( - hass=hass, api_key=api_key, base_url=base_url, api_version=api_version + hass=hass, + api_key=api_key, + base_url=base_url, + api_version=api_version, + skip_authentication=skip_authentication, ) @@ -151,63 +168,77 @@ async def async_step_init( return self.async_create_entry( title=user_input.get(CONF_NAME, DEFAULT_NAME), data=user_input ) - schema = openai_config_option_schema(self.config_entry.options) + schema = self.openai_config_option_schema(self.config_entry.options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) - -def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: - """Return a schema for OpenAI completion options.""" - if not options: - options = DEFAULT_OPTIONS - return { - vol.Optional( - CONF_PROMPT, - description={"suggested_value": options[CONF_PROMPT]}, - default=DEFAULT_PROMPT, - ): TemplateSelector(), - vol.Optional( - CONF_CHAT_MODEL, - description={ - # New key in HA 2023.4 - "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - }, - default=DEFAULT_CHAT_MODEL, - ): str, - vol.Optional( - CONF_MAX_TOKENS, - description={"suggested_value": options[CONF_MAX_TOKENS]}, - default=DEFAULT_MAX_TOKENS, - ): int, - vol.Optional( - CONF_TOP_P, - description={"suggested_value": options[CONF_TOP_P]}, - default=DEFAULT_TOP_P, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_TEMPERATURE, - description={"suggested_value": options[CONF_TEMPERATURE]}, - default=DEFAULT_TEMPERATURE, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, - description={ - "suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION] - }, - default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, - ): int, - vol.Optional( - CONF_FUNCTIONS, - description={"suggested_value": options.get(CONF_FUNCTIONS)}, - default=DEFAULT_CONF_FUNCTIONS_STR, - ): TemplateSelector(), - vol.Optional( - CONF_ATTACH_USERNAME, - description={ - "suggested_value": options.get(CONF_ATTACH_USERNAME) - }, - default=DEFAULT_ATTACH_USERNAME, - ): BooleanSelector(), - } + def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> dict: + """Return a schema for OpenAI completion options.""" + if not options: + options = DEFAULT_OPTIONS + + is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure" + + return { + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options[CONF_PROMPT]}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), + vol.Optional( + CONF_CHAT_MODEL, + description={ + # New key in HA 2023.4 + "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + }, + default=DEFAULT_CHAT_MODEL, + ): str, + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options[CONF_MAX_TOKENS]}, + default=DEFAULT_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options[CONF_TOP_P]}, + default=DEFAULT_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options[CONF_TEMPERATURE]}, + default=DEFAULT_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION, + description={ + "suggested_value": options[CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION] + }, + default=DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, + ): int, + vol.Optional( + CONF_FUNCTIONS, + description={"suggested_value": options.get(CONF_FUNCTIONS)}, + default=DEFAULT_CONF_FUNCTIONS_STR, + ): TemplateSelector(), + vol.Optional( + CONF_ATTACH_USERNAME, + description={ + "suggested_value": options.get(CONF_ATTACH_USERNAME) + }, + default=DEFAULT_ATTACH_USERNAME, + ): BooleanSelector(), + vol.Optional( + CONF_MODEL_KEY, + description={"suggested_value": options.get(CONF_MODEL_KEY)}, + default="engine" if is_azure else DEFAULT_MODEL_KEY, + ): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict(value=key, label=key) for key in MODEL_KEYS + ], + mode=SelectSelectorMode.DROPDOWN, + ) + ), + } diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py index 778d8fc..56a2233 100644 --- a/custom_components/extended_openai_conversation/const.py +++ b/custom_components/extended_openai_conversation/const.py @@ -2,7 +2,14 @@ DOMAIN = "extended_openai_conversation" DEFAULT_NAME = "Extended OpenAI Conversation" +CONF_BASE_URL = "base_url" +DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1" +CONF_API_VERSION = "api_version" +CONF_SKIP_AUTHENTICATION = "skip_authentication" +DEFAULT_SKIP_AUTHENTICATION = False + EVENT_AUTOMATION_REGISTERED = "automation_registered_via_extended_openai_conversation" + CONF_PROMPT = "prompt" DEFAULT_PROMPT = """I want you to act as smart home manager of Home Assistant. I will provide information of smart home along with a question, you will truthfully make correction or answer using information provided in one sentence in everyday language. @@ -75,8 +82,8 @@ "function": {"type": "native", "name": "execute_service"}, } ] -CONF_BASE_URL = "base_url" -DEFAULT_CONF_BASE_URL = "https://api.openai.com/v1" -CONF_API_VERSION = "api_version" CONF_ATTACH_USERNAME = "attach_username" DEFAULT_ATTACH_USERNAME = False +CONF_MODEL_KEY = "model_key" +DEFAULT_MODEL_KEY = "model" +MODEL_KEYS = ["model", "engine"] diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index 86f6e58..a500ff1 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -4,11 +4,12 @@ import yaml import time import sqlite3 +import openai +import re import voluptuous as vol +from functools import partial from bs4 import BeautifulSoup from typing import Any -from homeassistant.helpers.aiohttp_client import async_get_clientsession -from openai.error import AuthenticationError from urllib import parse from datetime import timedelta @@ -63,6 +64,9 @@ _LOGGER = logging.getLogger(__name__) +AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com" + + def get_function_executor(value: str): function_executor = FUNCTION_EXECUTORS.get(value) if function_executor is None: @@ -70,6 +74,12 @@ def get_function_executor(value: str): return function_executor +def get_api_type(base_url: str): + if base_url and re.search(AZURE_DOMAIN_PATTERN, base_url): + return "azure" + return None + + def convert_to_template( settings, template_keys=["data", "event_data", "target", "service"], @@ -122,25 +132,25 @@ def _get_rest_data(hass, rest_config, arguments): async def validate_authentication( - hass: HomeAssistant, api_key: str, base_url: str, api_version: str + hass: HomeAssistant, + api_key: str, + base_url: str, + api_version: str, + skip_authentication=False, ) -> None: - if not base_url: - base_url = DEFAULT_CONF_BASE_URL - - url = f"{base_url}/models" - if api_version: - url = f"{url}?api-version={api_version}" - - session = async_get_clientsession(hass) - response = await session.get( - url, - headers={"Authorization": f"Bearer {api_key}"}, - timeout=10, + if skip_authentication: + return + + await hass.async_add_executor_job( + partial( + openai.Model.list, + api_type=get_api_type(base_url), + api_key=api_key, + api_version=api_version, + api_base=base_url, + request_timeout=10, + ) ) - if response.status == 401: - raise AuthenticationError() - - response.raise_for_status() class FunctionExecutor(ABC): diff --git a/custom_components/extended_openai_conversation/strings.json b/custom_components/extended_openai_conversation/strings.json index 06b5512..752e35f 100644 --- a/custom_components/extended_openai_conversation/strings.json +++ b/custom_components/extended_openai_conversation/strings.json @@ -6,7 +6,8 @@ "name": "[%key:common::config_flow::data::name%]", "api_key": "[%key:common::config_flow::data::api_key%]", "base_url": "[%key:common::config_flow::data::base_url%]", - "api_version": "[%key:common::config_flow::data::api_version%]" + "api_version": "[%key:common::config_flow::data::api_version%]", + "skip_authentication": "[%key:common::config_flow::data::skip_authentication%]" } } }, @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/de.json b/custom_components/extended_openai_conversation/translations/de.json index 5aa8899..9c27154 100644 --- a/custom_components/extended_openai_conversation/translations/de.json +++ b/custom_components/extended_openai_conversation/translations/de.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -26,7 +27,9 @@ "temperature": "Temperatur", "top_p": "Top P", "max_function_calls_per_conversation": "Maximale Anzahl an Funktionsaufrufen pro Konversation", - "functions": "Funktionen" + "functions": "Funktionen", + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/en.json b/custom_components/extended_openai_conversation/translations/en.json index 61644a3..0a46be2 100644 --- a/custom_components/extended_openai_conversation/translations/en.json +++ b/custom_components/extended_openai_conversation/translations/en.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/ko.json b/custom_components/extended_openai_conversation/translations/ko.json index 61644a3..0a46be2 100644 --- a/custom_components/extended_openai_conversation/translations/ko.json +++ b/custom_components/extended_openai_conversation/translations/ko.json @@ -11,7 +11,8 @@ "name": "Name", "api_key": "API Key", "base_url": "Base Url", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -27,7 +28,8 @@ "top_p": "Top P", "max_function_calls_per_conversation": "Maximum function calls per conversation", "functions": "Functions", - "attach_username": "Attach Username to Message" + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } diff --git a/custom_components/extended_openai_conversation/translations/nl.json b/custom_components/extended_openai_conversation/translations/nl.json index 5e0c42a..3fa6114 100644 --- a/custom_components/extended_openai_conversation/translations/nl.json +++ b/custom_components/extended_openai_conversation/translations/nl.json @@ -11,7 +11,8 @@ "name": "Naam", "api_key": "API-sleutel", "base_url": "Basis-URL", - "api_version": "Api Version" + "api_version": "Api Version", + "skip_authentication": "Skip Authentication" } } } @@ -26,7 +27,9 @@ "temperature": "Temperatuur", "top_p": "Top P", "max_function_calls_per_conversation": "Maximale keren functies mogen worden aangeroepen per conversatie", - "functions": "Functies" + "functions": "Functies", + "attach_username": "Attach Username to Message", + "model_key": "Model Key" } } } From d871b603613bc023cf0ef2845029ae3552cc9965 Mon Sep 17 00:00:00 2001 From: jekalmin Date: Sun, 24 Dec 2023 18:37:56 +0900 Subject: [PATCH 5/6] [#42] fix default model key --- .../extended_openai_conversation/__init__.py | 10 ++++++---- .../extended_openai_conversation/config_flow.py | 13 +++++-------- .../extended_openai_conversation/const.py | 1 - .../extended_openai_conversation/helpers.py | 12 +++++++++++- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index 53d89bc..f3acefe 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -51,7 +51,6 @@ DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, DEFAULT_SKIP_AUTHENTICATION, - DEFAULT_MODEL_KEY, DOMAIN, ) @@ -79,6 +78,7 @@ validate_authentication, get_function_executor, get_api_type, + get_default_model_key, ) @@ -284,9 +284,11 @@ async def query( DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, ): function_call = "none" - model_kwargs = { - self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model - } + + model_key = self.entry.options.get( + CONF_MODEL_KEY, get_default_model_key(api_base) + ) + model_kwargs = {model_key: model} _LOGGER.info("Prompt for %s: %s", model, messages) diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index 3e2184d..168a06c 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -26,7 +26,7 @@ SelectSelectorMode, ) -from .helpers import validate_authentication, get_api_type +from .helpers import validate_authentication, get_default_model_key from .const import ( CONF_ATTACH_USERNAME, @@ -52,7 +52,6 @@ DEFAULT_CONF_FUNCTIONS, DEFAULT_CONF_BASE_URL, DEFAULT_SKIP_AUTHENTICATION, - DEFAULT_MODEL_KEY, DOMAIN, DEFAULT_NAME, ) @@ -179,8 +178,6 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di if not options: options = DEFAULT_OPTIONS - is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure" - return { vol.Optional( CONF_PROMPT, @@ -224,15 +221,15 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di ): TemplateSelector(), vol.Optional( CONF_ATTACH_USERNAME, - description={ - "suggested_value": options.get(CONF_ATTACH_USERNAME) - }, + description={"suggested_value": options.get(CONF_ATTACH_USERNAME)}, default=DEFAULT_ATTACH_USERNAME, ): BooleanSelector(), vol.Optional( CONF_MODEL_KEY, description={"suggested_value": options.get(CONF_MODEL_KEY)}, - default="engine" if is_azure else DEFAULT_MODEL_KEY, + default=get_default_model_key( + self.config_entry.data.get(CONF_BASE_URL) + ), ): SelectSelector( SelectSelectorConfig( options=[ diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py index 56a2233..2a3fc10 100644 --- a/custom_components/extended_openai_conversation/const.py +++ b/custom_components/extended_openai_conversation/const.py @@ -85,5 +85,4 @@ CONF_ATTACH_USERNAME = "attach_username" DEFAULT_ATTACH_USERNAME = False CONF_MODEL_KEY = "model_key" -DEFAULT_MODEL_KEY = "model" MODEL_KEYS = ["model", "engine"] diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index a500ff1..2a43d9d 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -58,7 +58,10 @@ FunctionNotFound, ) -from .const import DOMAIN, EVENT_AUTOMATION_REGISTERED, DEFAULT_CONF_BASE_URL +from .const import ( + DOMAIN, + EVENT_AUTOMATION_REGISTERED, +) _LOGGER = logging.getLogger(__name__) @@ -80,6 +83,13 @@ def get_api_type(base_url: str): return None +def get_default_model_key(base_url: str): + is_azure = get_api_type(base_url) == "azure" + if is_azure: + return "engine" + return "model" + + def convert_to_template( settings, template_keys=["data", "event_data", "target", "service"], From 16240af8c9e63aba1a498a43be8e419907ee8489 Mon Sep 17 00:00:00 2001 From: jekalmin Date: Sun, 24 Dec 2023 22:48:26 +0900 Subject: [PATCH 6/6] [#57] make functions optional --- custom_components/extended_openai_conversation/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index f3acefe..68c4d39 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -284,6 +284,9 @@ async def query( DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, ): function_call = "none" + if len(functions) == 0: + functions = None + function_call = None model_key = self.entry.options.get( CONF_MODEL_KEY, get_default_model_key(api_base)