Skip to content

Commit

Permalink
[#42] fix default model key
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Dec 24, 2023
1 parent 9e0b346 commit 6dc8810
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
10 changes: 6 additions & 4 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_MODEL_KEY,
DOMAIN,
)

Expand Down Expand Up @@ -79,6 +78,7 @@
validate_authentication,
get_function_executor,
get_api_type,
get_default_model_key,
)


Expand Down Expand Up @@ -284,9 +284,11 @@ async def query(
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
):
function_call = "none"
model_kwargs = {
self.entry.options.get(CONF_MODEL_KEY, DEFAULT_MODEL_KEY): model
}

model_key = self.entry.options.get(
CONF_MODEL_KEY, get_default_model_key(api_base)
)
model_kwargs = {model_key: model}

_LOGGER.info("Prompt for %s: %s", model, messages)

Expand Down
13 changes: 5 additions & 8 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
SelectSelectorMode,
)

from .helpers import validate_authentication, get_api_type
from .helpers import validate_authentication, get_default_model_key

from .const import (
CONF_ATTACH_USERNAME,
Expand All @@ -52,7 +52,6 @@
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONF_BASE_URL,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_MODEL_KEY,
DOMAIN,
DEFAULT_NAME,
)
Expand Down Expand Up @@ -179,8 +178,6 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
if not options:
options = DEFAULT_OPTIONS

is_azure = get_api_type(self.config_entry.data.get(CONF_BASE_URL)) == "azure"

return {
vol.Optional(
CONF_PROMPT,
Expand Down Expand Up @@ -224,15 +221,15 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
): TemplateSelector(),
vol.Optional(
CONF_ATTACH_USERNAME,
description={
"suggested_value": options.get(CONF_ATTACH_USERNAME)
},
description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
vol.Optional(
CONF_MODEL_KEY,
description={"suggested_value": options.get(CONF_MODEL_KEY)},
default="engine" if is_azure else DEFAULT_MODEL_KEY,
default=get_default_model_key(
self.config_entry.data.get(CONF_BASE_URL)
),
): SelectSelector(
SelectSelectorConfig(
options=[
Expand Down
1 change: 0 additions & 1 deletion custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,4 @@
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False
CONF_MODEL_KEY = "model_key"
DEFAULT_MODEL_KEY = "model"
MODEL_KEYS = ["model", "engine"]
12 changes: 11 additions & 1 deletion custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
FunctionNotFound,
)

from .const import DOMAIN, EVENT_AUTOMATION_REGISTERED, DEFAULT_CONF_BASE_URL
from .const import (
DOMAIN,
EVENT_AUTOMATION_REGISTERED,
)


_LOGGER = logging.getLogger(__name__)
Expand All @@ -80,6 +83,13 @@ def get_api_type(base_url: str):
return None


def get_default_model_key(base_url: str):
is_azure = get_api_type(base_url) == "azure"
if is_azure:
return "engine"
return "model"


def convert_to_template(
settings,
template_keys=["data", "event_data", "target", "service"],
Expand Down

0 comments on commit 6dc8810

Please sign in to comment.