Skip to content

Commit

Permalink
Use tools instead of functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
rkistner authored and jekalmin committed Jan 19, 2024
1 parent 3697240 commit e4b07b1
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 36 deletions.
80 changes: 74 additions & 6 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand All @@ -55,6 +56,7 @@
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_CONF_FUNCTIONS,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DOMAIN,
)

Expand Down Expand Up @@ -281,16 +283,24 @@ async def query(
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
use_tools = self.entry.options.get(CONF_USE_TOOLS, DEFAULT_USE_TOOLS)
functions = list(map(lambda s: s["spec"], self.get_functions()))
function_call = "auto"
if n_requests == self.entry.options.get(
CONF_MAX_FUNCTION_CALLS_PER_CONVERSATION,
DEFAULT_MAX_FUNCTION_CALLS_PER_CONVERSATION,
):
function_call = "none"

tool_kwargs = {"functions": functions, "function_call": function_call}
if use_tools:
tool_kwargs = {
"tools": [{"type": "function", "function": func} for func in functions],
"tool_choice": function_call,
}

if len(functions) == 0:
functions = None
function_call = None
tool_kwargs = {}

_LOGGER.info("Prompt for %s: %s", model, messages)

Expand All @@ -301,20 +311,24 @@ async def query(
top_p=top_p,
temperature=temperature,
user=user_input.conversation_id,
functions=functions,
function_call=function_call,
**tool_kwargs,
)

_LOGGER.info("Response %s", response.model_dump(exclude_none=True))
choice: Choice = response.choices[0]
message = choice.message

if choice.finish_reason == "function_call":
message = await self.execute_function_call(
user_input, messages, message, exposed_entities, n_requests + 1
)
if choice.finish_reason == "tool_calls":
message = await self.execute_tool_calls(
user_input, messages, message, exposed_entities, n_requests + 1
)
return message

def execute_function_call(
async def execute_function_call(
self,
user_input: conversation.ConversationInput,
messages,
Expand All @@ -328,7 +342,7 @@ def execute_function_call(
None,
)
if function is not None:
return self.execute_function(
return await self.execute_function(
user_input,
messages,
message,
Expand Down Expand Up @@ -366,3 +380,57 @@ async def execute_function(
}
)
return await self.query(user_input, messages, exposed_entities, n_requests)

async def execute_tool_calls(
self,
user_input: conversation.ConversationInput,
messages,
message: ChatCompletionMessage,
exposed_entities,
n_requests,
):
messages.append(message.model_dump(exclude_none=True))
for tool in message.tool_calls:
function_name = tool.function.name
function = next(
(s for s in self.get_functions() if s["spec"]["name"] == function_name),
None,
)
if function is not None:
result = await self.execute_tool_function(
user_input,
tool,
exposed_entities,
function,
)

messages.append(
{
"tool_call_id": tool.id,
"role": "tool",
"name": function_name,
"content": str(result),
}
)
else:
raise FunctionNotFound(function_name)
return await self.query(user_input, messages, exposed_entities, n_requests)

async def execute_tool_function(
self,
user_input: conversation.ConversationInput,
tool,
exposed_entities,
function,
):
function_executor = get_function_executor(function["function"]["type"])

try:
arguments = json.loads(tool.function.arguments)
except json.decoder.JSONDecodeError as err:
raise ParseArgumentsFailed(tool.function.arguments) from err

result = await function_executor.execute(
self.hass, function["function"], arguments, user_input, exposed_entities
)
return result
8 changes: 8 additions & 0 deletions custom_components/extended_openai_conversation/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
CONF_BASE_URL,
CONF_API_VERSION,
CONF_SKIP_AUTHENTICATION,
CONF_USE_TOOLS,
DEFAULT_ATTACH_USERNAME,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
Expand All @@ -50,6 +51,7 @@
DEFAULT_CONF_FUNCTIONS,
DEFAULT_CONF_BASE_URL,
DEFAULT_SKIP_AUTHENTICATION,
DEFAULT_USE_TOOLS,
DOMAIN,
DEFAULT_NAME,
)
Expand Down Expand Up @@ -80,6 +82,7 @@
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
CONF_FUNCTIONS: DEFAULT_CONF_FUNCTIONS_STR,
CONF_ATTACH_USERNAME: DEFAULT_ATTACH_USERNAME,
CONF_USE_TOOLS: DEFAULT_USE_TOOLS,
}
)

Expand Down Expand Up @@ -222,4 +225,9 @@ def openai_config_option_schema(self, options: MappingProxyType[str, Any]) -> di
description={"suggested_value": options.get(CONF_ATTACH_USERNAME)},
default=DEFAULT_ATTACH_USERNAME,
): BooleanSelector(),
vol.Optional(
CONF_USE_TOOLS,
description={"suggested_value": options.get(CONF_USE_TOOLS)},
default=DEFAULT_USE_TOOLS,
): BooleanSelector(),
}
2 changes: 2 additions & 0 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,7 @@
]
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False
CONF_USE_TOOLS = "use_tools"
DEFAULT_USE_TOOLS = False

SERVICE_QUERY_IMAGE = "query_image"
79 changes: 49 additions & 30 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ async def validate_authentication(
return

if is_azure(base_url):
client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
client = AsyncAzureOpenAI(
api_key=api_key, azure_endpoint=base_url, api_version=api_version
)
else:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

Expand Down Expand Up @@ -197,6 +199,10 @@ async def execute(
return await self.execute_service(
hass, function, arguments, user_input, exposed_entities
)
if name == "execute_service_single":
return await self.execute_service_single(
hass, function, arguments, user_input, exposed_entities
)
if name == "add_automation":
return await self.add_automation(
hass, function, arguments, user_input, exposed_entities
Expand All @@ -208,6 +214,44 @@ async def execute(

raise NativeNotFound(name)

async def execute_service_single(
self,
hass: HomeAssistant,
function,
service_argument,
user_input: conversation.ConversationInput,
exposed_entities,
):
domain = service_argument["domain"]
service = service_argument["service"]
service_data = service_argument.get(
"service_data", service_argument.get("data", {})
)
entity_id = service_data.get("entity_id", service_argument.get("entity_id"))
area_id = service_data.get("area_id")
device_id = service_data.get("device_id")

if isinstance(entity_id, str):
entity_id = [e.strip() for e in entity_id.split(",")]
service_data["entity_id"] = entity_id

if entity_id is None and area_id is None and device_id is None:
raise CallServiceError(domain, service, service_data)
if not hass.services.has_service(domain, service):
raise ServiceNotFound(domain, service)
self.validate_entity_ids(hass, entity_id or [], exposed_entities)

try:
await hass.services.async_call(
domain=domain,
service=service,
service_data=service_data,
)
return {"success": True}
except HomeAssistantError as e:
_LOGGER.error(e)
return {"error": str(e)}

async def execute_service(
self,
hass: HomeAssistant,
Expand All @@ -218,36 +262,11 @@ async def execute_service(
):
result = []
for service_argument in arguments.get("list", []):
domain = service_argument["domain"]
service = service_argument["service"]
service_data = service_argument.get(
"service_data", service_argument.get("data", {})
)
entity_id = service_data.get("entity_id", service_argument.get("entity_id"))
area_id = service_data.get("area_id")
device_id = service_data.get("device_id")

if isinstance(entity_id, str):
entity_id = [e.strip() for e in entity_id.split(",")]
service_data["entity_id"] = entity_id

if entity_id is None and area_id is None and device_id is None:
raise CallServiceError(domain, service, service_data)
if not hass.services.has_service(domain, service):
raise ServiceNotFound(domain, service)
self.validate_entity_ids(hass, entity_id or [], exposed_entities)

try:
await hass.services.async_call(
domain=domain,
service=service,
service_data=service_data,
result.append(
await self.execute_service_single(
hass, function, service_argument, user_input, exposed_entities
)
result.append(True)
except HomeAssistantError as e:
_LOGGER.error(e)
result.append(False)

)
return result

async def add_automation(
Expand Down

0 comments on commit e4b07b1

Please sign in to comment.