Skip to content

Commit

Permalink
[#43] Add "query_image" service
Browse files Browse the repository at this point in the history
  • Loading branch information
jekalmin authored and jekalmin committed Jan 17, 2024
1 parent 1ee4861 commit 0f3d847
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 65 deletions.
74 changes: 12 additions & 62 deletions custom_components/extended_openai_conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Literal
import json
import yaml
import voluptuous as vol

from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import (
Expand All @@ -18,12 +17,7 @@
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL, ATTR_NAME
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
Expand All @@ -38,7 +32,6 @@
intent,
template,
entity_registry as er,
selector,
)

from .const import (
Expand Down Expand Up @@ -78,6 +71,8 @@
is_azure,
)

from .services import async_setup_services


_LOGGER = logging.getLogger(__name__)

Expand All @@ -86,64 +81,14 @@

# hass.data key for agent.
DATA_AGENT = "agent"
SERVICE_QUERY_IMAGE = "query_image"


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up OpenAI Conversation."""

async def query_image(call: ServiceCall) -> ServiceResponse:
"""Query an image."""
try:
model = call.data["model"]
images = [
{"type": "image_url", "image_url": image}
for image in call.data["images"]
]

messages = [
{
"role": "user",
"content": [{"type": "text", "text": call.data["prompt"]}] + images,
}
]
_LOGGER.info("Prompt for %s: %s", model, messages)

response = await openai.ChatCompletion.acreate(
api_key=hass.data[DOMAIN][call.data["config_entry"]]["api_key"],
model=model,
messages=messages,
max_tokens=call.data["max_tokens"],
)
_LOGGER.info("Response %s", response)
except error.OpenAIError as err:
raise HomeAssistantError(f"Error generating image: {err}") from err

return response

hass.services.async_register(
DOMAIN,
SERVICE_QUERY_IMAGE,
query_image,
schema=vol.Schema(
{
vol.Required("config_entry"): selector.ConfigEntrySelector(
{
"integration": DOMAIN,
}
),
vol.Required("model", default="gpt-4-vision-preview"): cv.string,
vol.Required("prompt"): cv.string,
vol.Required("images"): vol.All(cv.ensure_list, [{"url": cv.url}]),
vol.Optional("max_tokens", default=300): cv.positive_int,
}
),
supports_response=SupportsResponse.ONLY,
)
await async_setup_services(hass, config)
return True



async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up OpenAI Conversation from a config entry."""

Expand Down Expand Up @@ -190,9 +135,15 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
self.history: dict[str, list[dict]] = {}
base_url = entry.data.get(CONF_BASE_URL)
if is_azure(base_url):
self.client = AsyncAzureOpenAI(api_key=entry.data[CONF_API_KEY], azure_endpoint=base_url, api_version=entry.data.get(CONF_API_VERSION))
self.client = AsyncAzureOpenAI(
api_key=entry.data[CONF_API_KEY],
azure_endpoint=base_url,
api_version=entry.data.get(CONF_API_VERSION),
)
else:
self.client = AsyncOpenAI(api_key=entry.data[CONF_API_KEY], base_url=base_url)
self.client = AsyncOpenAI(
api_key=entry.data[CONF_API_KEY], base_url=base_url
)

@property
def supported_languages(self) -> list[str] | Literal["*"]:
Expand Down Expand Up @@ -354,7 +305,6 @@ async def query(
function_call=function_call,
)


_LOGGER.info("Response %s", response.model_dump(exclude_none=True))
choice: Choice = response.choices[0]
message = choice.message
Expand Down
2 changes: 2 additions & 0 deletions custom_components/extended_openai_conversation/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@
]
CONF_ATTACH_USERNAME = "attach_username"
DEFAULT_ATTACH_USERNAME = False

SERVICE_QUERY_IMAGE = "query_image"
76 changes: 76 additions & 0 deletions custom_components/extended_openai_conversation/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging

import voluptuous as vol
from openai import AsyncOpenAI
from openai._exceptions import OpenAIError

from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers import selector, config_validation as cv

from .const import DOMAIN, SERVICE_QUERY_IMAGE

# Voluptuous schema validating the input of the `query_image` service call.
QUERY_IMAGE_SCHEMA = vol.Schema(
    {
        # Config entry of this integration; supplies the OpenAI API key used
        # by the service handler.
        vol.Required("config_entry"): selector.ConfigEntrySelector(
            {
                "integration": DOMAIN,
            }
        ),
        # Vision-capable chat model; defaults to GPT-4 with vision.
        vol.Required("model", default="gpt-4-vision-preview"): cv.string,
        # The question to ask about the supplied image(s).
        vol.Required("prompt"): cv.string,
        # List of {"url": <image url>} mappings attached to the prompt.
        vol.Required("images"): vol.All(cv.ensure_list, [{"url": cv.url}]),
        # Upper bound on tokens generated in the model's answer.
        vol.Optional("max_tokens", default=300): cv.positive_int,
    }
)

# Logger named after the integration package rather than this module.
_LOGGER = logging.getLogger(__package__)


async def async_setup_services(hass: HomeAssistant, config: ConfigType) -> None:
    """Set up services for the extended openai conversation component.

    Registers the `query_image` service, which answers questions about one
    or more images via an OpenAI vision-capable chat model.
    """

    async def query_image(call: ServiceCall) -> ServiceResponse:
        """Query an image.

        Sends the caller's prompt plus the attached image URLs to the
        configured OpenAI chat model and returns the raw completion as a
        dict (the service is registered with ``SupportsResponse.ONLY``).

        Raises:
            HomeAssistantError: if the OpenAI API reports an error.
        """
        model = call.data["model"]
        images = [
            {"type": "image_url", "image_url": image}
            for image in call.data["images"]
        ]

        # Single user message combining the text prompt and all images, as
        # expected by the OpenAI vision chat-completions message format.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": call.data["prompt"]}]
                + images,
            }
        ]
        _LOGGER.info("Prompt for %s: %s", model, messages)

        try:
            # The API key is looked up from the config entry selected by the
            # caller; a fresh client is created per call.
            # NOTE(review): base_url / Azure settings from the config entry
            # are not applied here — confirm whether Azure endpoints should
            # also be supported by this service.
            response = await AsyncOpenAI(
                api_key=hass.data[DOMAIN][call.data["config_entry"]]["api_key"]
            ).chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=call.data["max_tokens"],
            )
        except OpenAIError as err:
            # Fixed message: this service queries an image, it does not
            # generate one (the old text was copied from another service).
            raise HomeAssistantError(f"Error querying image: {err}") from err

        # Serialize outside the try so only genuine OpenAI failures are
        # wrapped as HomeAssistantError above.
        response_dict = response.model_dump()
        _LOGGER.info("Response %s", response_dict)
        return response_dict

    hass.services.async_register(
        DOMAIN,
        SERVICE_QUERY_IMAGE,
        query_image,
        schema=QUERY_IMAGE_SCHEMA,
        supports_response=SupportsResponse.ONLY,
    )
32 changes: 32 additions & 0 deletions custom_components/extended_openai_conversation/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,37 @@
}
}
}
},
"services": {
"query_image": {
"name": "Query image",
"description": "Take in images and answer questions about them",
"fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service"
},
"model": {
"name": "Model",
"description": "The model",
"example": "gpt-4-vision-preview"
},
"prompt": {
  "name": "Prompt",
  "description": "The text to ask about the image",
  "example": "What’s in this image?"
},
"images": {
  "name": "Images",
  "description": "A list of images to ask questions about",
  "example": "{\"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"}"
},
"max_tokens": {
  "name": "Max Tokens",
  "description": "The maximum number of tokens to generate",
  "example": "300"
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,37 @@
}
}
}
},
"services": {
"query_image": {
"name": "Query image",
"description": "Take in images and answer questions about them",
"fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service"
},
"model": {
"name": "Model",
"description": "The model",
"example": "gpt-4-vision-preview"
},
"prompt": {
  "name": "Prompt",
  "description": "The text to ask about the image",
  "example": "What’s in this image?"
},
"images": {
  "name": "Images",
  "description": "A list of images to ask questions about",
  "example": "{\"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"}"
},
"max_tokens": {
  "name": "Max Tokens",
  "description": "The maximum number of tokens to generate",
  "example": "300"
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,37 @@
}
}
}
},
"services": {
"query_image": {
"name": "Query image",
"description": "Take in images and answer questions about them",
"fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service"
},
"model": {
"name": "Model",
"description": "The model",
"example": "gpt-4-vision-preview"
},
"prompt": {
  "name": "Prompt",
  "description": "The text to ask about the image",
  "example": "What’s in this image?"
},
"images": {
  "name": "Images",
  "description": "A list of images to ask questions about",
  "example": "{\"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"}"
},
"max_tokens": {
  "name": "Max Tokens",
  "description": "The maximum number of tokens to generate",
  "example": "300"
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,37 @@
}
}
}
},
"services": {
"query_image": {
"name": "Query image",
"description": "Take in images and answer questions about them",
"fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service"
},
"model": {
"name": "Model",
"description": "The model",
"example": "gpt-4-vision-preview"
},
"prompt": {
  "name": "Prompt",
  "description": "The text to ask about the image",
  "example": "What’s in this image?"
},
"images": {
  "name": "Images",
  "description": "A list of images to ask questions about",
  "example": "{\"url\": \"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg\"}"
},
"max_tokens": {
  "name": "Max Tokens",
  "description": "The maximum number of tokens to generate",
  "example": "300"
}
}
}
}
}
}
}
Loading

0 comments on commit 0f3d847

Please sign in to comment.