From d871b603613bc023cf0ef2845029ae3552cc9965 Mon Sep 17 00:00:00 2001 From: jekalmin Date: Sun, 24 Dec 2023 18:37:56 +0900 Subject: [PATCH] [#42] fix default model key --- .../extended_openai_conversation/__init__.py | 10 ++++++---- .../extended_openai_conversation/config_flow.py | 13 +++++-------- .../extended_openai_conversation/const.py | 1 - .../extended_openai_conversation/helpers.py | 12 +++++++++++- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/custom_components/extended_openai_conversation/__init__.py b/custom_components/extended_openai_conversation/__init__.py index 53d89bc..f3acefe 100644 --- a/custom_components/extended_openai_conversation/__init__.py +++ b/custom_components/extended_openai_conversation/__init__.py @@ -51,7 +51,6 @@ DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, DEFAULT_CONF_FUNCTIONS, DEFAULT_SKIP_AUTHENTICATION, - DEFAULT_MODEL_KEY, DOMAIN, ) @@ -79,6 +78,7 @@ validate_authentication, get_function_executor, get_api_type, + get_default_model_key, ) @@ -284,9 +284,11 @@ async def query( DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION, ): function_call = "none" - model_kwargs = { - self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model - } + + model_key = self.entry.options.get( + CONF_MODEL_KEY, get_default_model_key(api_base) + ) + model_kwargs = {model_key: model} _LOGGER.info("Prompt for %s: %s", model, messages) diff --git a/custom_components/extended_openai_conversation/config_flow.py b/custom_components/extended_openai_conversation/config_flow.py index 3e2184d..168a06c 100644 --- a/custom_components/extended_openai_conversation/config_flow.py +++ b/custom_components/extended_openai_conversation/config_flow.py @@ -26,7 +26,7 @@ SelectSelectorMode, ) -from .helpers import validate_authentication, get_api_type +from .helpers import validate_authentication, get_default_model_key from .const import ( CONF_ATTACH_USERNAME, @@ -52,7 +52,6 @@ DEFAULT_CONF_FUNCTIONS, DEFAULT_CONF_BASE_URL, DEFAULT_SKIP_AUTHENTICATION, - DEFAULT_MODEL_KEY, DOMAIN, DEFAULT_NAME, ) @@ -179,8 +178,6 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di if not options: options = DEFAULT_OPTIONS - is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure" - return { vol.Optional( CONF_PROMPT, @@ -224,15 +221,15 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di ): TemplateSelector(), vol.Optional( CONF_ATTACH_USERNAME, - description={ - "suggested_value": options.get(CONF_ATTACH_USERNAME) - }, + description={"suggested_value": options.get(CONF_ATTACH_USERNAME)}, default=DEFAULT_ATTACH_USERNAME, ): BooleanSelector(), vol.Optional( CONF_MODEL_KEY, description={"suggested_value": options.get(CONF_MODEL_KEY)}, - default="engine" if is_azure else DEFAULT_MODEL_KEY, + default=get_default_model_key( + self.config_entry.data.get(CONF_BASE_URL) + ), ): SelectSelector( SelectSelectorConfig( options=[ diff --git a/custom_components/extended_openai_conversation/const.py b/custom_components/extended_openai_conversation/const.py index 56a2233..2a3fc10 100644 --- a/custom_components/extended_openai_conversation/const.py +++ b/custom_components/extended_openai_conversation/const.py @@ -85,5 +85,4 @@ CONF_ATTACH_USERNAME = "attach_username" DEFAULT_ATTACH_USERNAME = False CONF_MODEL_KEY = "model_key" -DEFAULT_MODEL_KEY = "model" MODEL_KEYS = ["model", "engine"] diff --git a/custom_components/extended_openai_conversation/helpers.py b/custom_components/extended_openai_conversation/helpers.py index a500ff1..2a43d9d 100644 --- a/custom_components/extended_openai_conversation/helpers.py +++ b/custom_components/extended_openai_conversation/helpers.py @@ -58,7 +58,10 @@ FunctionNotFound, ) -from .const import DOMAIN, EVENT_AUTOMATION_REGISTERED, DEFAULT_CONF_BASE_URL +from .const import ( + DOMAIN, + EVENT_AUTOMATION_REGISTERED, +) _LOGGER = logging.getLogger(__name__) @@ -80,6 +83,13 @@ def get_api_type(base_url: str): return None +def get_default_model_key(base_url: str): + is_azure = get_api_type(base_url) == "azure" + if is_azure: + return "engine" + return "model" + + def convert_to_template( settings, template_keys=["data", "event_data", "target", "service"],