Skip to content

Commit

Permalink
[#86] "area_id" and "deviced_id" can be set to execute_service
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin committed Jan 13, 2024
1 parent 094ce16 commit cfd3b37
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 29 deletions.
18 changes: 2 additions & 16 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ConfigEntryNotReady,
HomeAssistantError,
TemplateError,
ServiceNotFound,
)

from homeassistant.helpers import (
Expand Down Expand Up @@ -59,26 +58,13 @@
)

from .exceptions import (
EntityNotFound,
EntityNotExposed,
CallServiceError,
FunctionNotFound,
NativeNotFound,
FunctionLoadFailed,
ParseArgumentsFailed,
InvalidFunction,
)

from .helpers import (
FUNCTION_EXECUTORS,
FunctionExecutor,
NativeFunctionExecutor,
ScriptFunctionExecutor,
TemplateFunctionExecutor,
RestFunctionExecutor,
ScrapeFunctionExecutor,
CompositeFunctionExecutor,
convert_to_template,
validate_authentication,
get_function_executor,
is_azure,
Expand All @@ -88,13 +74,13 @@
_LOGGER = logging.getLogger(__name__)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
AZURE_DOMAIN_PATTERN = r"\.openai\.azure\.com"


# hass.data key for agent.
DATA_AGENT = "agent"



async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up OpenAI Conversation from a config entry."""

Expand Down Expand Up @@ -306,7 +292,7 @@ async def query(
)


_LOGGER.info("Response %s", response)
_LOGGER.info("Response %s", response.model_dump(exclude_none=True))
choice: Choice = response.choices[0]
message = choice.message
if choice.finish_reason == "function_call":
Expand Down
4 changes: 2 additions & 2 deletions custom_components/extended_openai_conversation/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def __init__(self, domain: str, service: str, data: object) -> None:
"""Initialize error."""
super().__init__(
self,
f"unable to call service {domain}.{service} with data {data}. 'entity_id' is required",
f"unable to call service {domain}.{service} with data {data}. One of 'entity_id', 'area_id', or 'device_id' is required",
)
self.domain = domain
self.service = service
self.data = data

def __str__(self) -> str:
"""Return string representation."""
return f"unable to call service {self.domain}.{self.service} with data {self.data}. 'entity_id' is required"
return f"unable to call service {self.domain}.{self.service} with data {self.data}. One of 'entity_id', 'area_id', or 'device_id' is required"


class FunctionNotFound(HomeAssistantError):
Expand Down
17 changes: 6 additions & 11 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
automation,
rest,
scrape,
history,
conversation,
recorder,
)
Expand All @@ -38,13 +37,7 @@
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.template import Template
from homeassistant.helpers.script import (
Script,
SCRIPT_MODE_SINGLE,
SCRIPT_MODE_PARALLEL,
DEFAULT_MAX,
DEFAULT_MAX_EXCEEDED,
)
from homeassistant.helpers.script import Script
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound


Expand Down Expand Up @@ -231,16 +224,18 @@ async def execute_service(
"service_data", service_argument.get("data", {})
)
entity_id = service_data.get("entity_id", service_argument.get("entity_id"))
area_id = service_data.get("area_id")
device_id = service_data.get("device_id")

if isinstance(entity_id, str):
entity_id = [e.strip() for e in entity_id.split(",")]
service_data["entity_id"] = entity_id

if entity_id is None:
if entity_id is None and area_id is None and device_id is None:
raise CallServiceError(domain, service, service_data)
if not hass.services.has_service(domain, service):
raise ServiceNotFound(domain, service)
self.validate_entity_ids(hass, entity_id, exposed_entities)
self.validate_entity_ids(hass, entity_id or [], exposed_entities)

try:
await hass.services.async_call(
Expand All @@ -249,7 +244,7 @@ async def execute_service(
service_data=service_data,
)
result.append(True)
except HomeAssistantError:
except HomeAssistantError as e:
_LOGGER.error(e)
result.append(False)

Expand Down

0 comments on commit cfd3b37

Please sign in to comment.