Skip to content

Commit

Permalink
Complete implementation of AI Tools #41
Browse files Browse the repository at this point in the history
Use Google Search, Bing search, Tavily web search for web search (e.g. can you give me some articles to read about anxiety in the workplace?)
 Use Google places for maps search (e.g. Where are some therapists in my area?)

added youtube too
  • Loading branch information
dhrumilp12 committed Sep 8, 2024
1 parent 8758661 commit 99e9c10
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 35 deletions.
8 changes: 7 additions & 1 deletion server/agents/mental_health_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from langchain_core.messages import trim_messages
from langchain_core.messages.human import HumanMessage


# MongoDB
# -- Custom modules --
from .ai_agent import AIAgent
Expand Down Expand Up @@ -65,6 +66,8 @@ def __init__(self, system_message: str = SYSTEM_MESSAGE, tool_names: list[str] =
"""
super().__init__(system_message, tool_names)



self.prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_message.content),
Expand Down Expand Up @@ -273,7 +276,10 @@ def run(self, message: str, with_history:bool =True, user_id: str=None, chat_id:
{"input": message, "user_id": user_id, "agent_scratchpad": []},
config={"configurable": {"session_id": session_id}})

return invocation["output"]
response = invocation["output"]
if isinstance(response, dict):
response = {k: (v if isinstance(v, (str, int, float, bool, list, dict)) else str(v)) for k, v in response.items()}
return response


def get_initial_greeting(self, user_id:str) -> dict:
Expand Down
52 changes: 23 additions & 29 deletions server/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,14 @@
import os

# from azure.identity import DefaultAzureCredential
# from azure.mgmt.maps import AzureMapsManagementClient

from langchain_google_community import GooglePlacesTool, GooglePlacesAPIWrapper

from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.tools.bing_search import BingSearchResults
from langchain_community.tools.tavily_search import TavilySearchResults

from utils.docs import format_docs

from services.db.user import get_user_profile_by_user_id
from langchain.tools import Tool
from utils.agents import get_google_search_results, get_bing_search_results, get_youtube_search_results
from langchain_google_community import GooglePlacesTool, GooglePlacesAPIWrapper


def get_maps_results(self, query):
"""
Searches for places that match the query and returns a list of results.
Args:
query (str): The query to search for in the map.
"""

# sub_id = os.environ.get("AZURE_SUBSCRIPTION_ID")
# client = AzureMapsManagementClient(credential=DefaultAzureCredential(), subscription_id=sub_id)

pass



def get_vector_store_chain(agent, collection_name:str):
Expand All @@ -38,11 +20,11 @@ def vector_store_chain_factory(collection_name) -> callable:
return lambda x: get_vector_store_chain(collection_name=collection_name)



toolbox = {
"community": {
"web_search_bing": BingSearchResults(api_wrapper=BingSearchAPIWrapper(k=1)),
"web_search_tavily": TavilySearchResults(),
"location_search_gplaces": GooglePlacesTool()
"location_search_gplaces": GooglePlacesTool(),
},
"custom": {
"agent_facts": {
Expand All @@ -51,12 +33,24 @@ def vector_store_chain_factory(collection_name) -> callable:
"retriever": True,
"structured": False
},
# "location_search": {
# "func": get_maps_results,
# "description": "Searches for places that match the query and returns a list of results.",
# "retriever": False,
# "structured": False
# },
"web_search_bing": {
"func": get_bing_search_results,
"description": "Uses Google Custom Search to fetch search results for a given query.",
"retriever": False,
"structured": True
},
"web_search_google": {
"func": get_google_search_results,
"description": "Uses Google Custom Search to fetch search results for a given query.",
"retriever": False,
"structured": True
},
"web_search_youtube": {
"func": get_youtube_search_results,
"description": "Uses YouTube Search to fetch search results for a given query.",
"retriever": False,
"structured": True
},
"user_profile_retrieval": {
"func": get_user_profile_by_user_id,
"structured": True,
Expand Down
6 changes: 3 additions & 3 deletions server/routes/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@ai_routes.post("/ai/mental_health/welcome/<user_id>")
def get_mental_health_agent_welcome(user_id):
agent = MentalHealthAIAgent(tool_names=["location_search", "web_search_bing", "user_profile_retrieval", "agent_facts"])
agent = MentalHealthAIAgent(tool_names=["web_search_youtube","web_search_tavily","wiki_search","web_search_bing","location_search_gplaces", "web_search_google", "user_profile_retrieval", "agent_facts"])

response = agent.get_initial_greeting(
user_id=user_id
Expand All @@ -36,7 +36,7 @@ def run_mental_health_agent(user_id, chat_id):
prompt = body.get("prompt")
turn_id = body.get("turn_id")

agent = MentalHealthAIAgent(tool_names=["location_search", "web_search_bing", "user_profile_retrieval", "agent_facts"])
agent = MentalHealthAIAgent(tool_names=["web_search_youtube","web_search_google","wiki_search","web_search_tavily","location_search_gplaces", "web_search_bing", "user_profile_retrieval", "agent_facts"])

try:

Expand All @@ -61,7 +61,7 @@ def run_mental_health_agent(user_id, chat_id):
def set_mental_health_end_state(user_id, chat_id):
try:
logger.info(f"Finalizing chat {chat_id} for user {user_id}")
agent = MentalHealthAIAgent(tool_names=["location_search", "web_search_bing", "user_profile_retrieval", "agent_facts"])
agent = MentalHealthAIAgent(tool_names=["web_search_youtube","web_search_tavily","wiki_search","web_search_bing","location_search_gplaces", "web_search_google", "user_profile_retrieval", "agent_facts"])

agent.perform_final_processes(user_id, chat_id)

Expand Down
4 changes: 2 additions & 2 deletions server/routes/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

load_dotenv()

GOOGLE_CUSTOME_SEARCH_API_KEY = os.getenv('GOOGLE_CUSTOME_SEARCH_API_KEY')
GOOGLE_SEARCH_CSE_ID = os.getenv('GOOGLE_SEARCH_CSE_ID')
GOOGLE_CUSTOME_SEARCH_API_KEY = os.getenv('GOOGLE_API_KEY')
GOOGLE_SEARCH_CSE_ID = os.getenv('GOOGLE_CSE_ID')
YOUTUBE_API_KEY = os.getenv('YOUTUBE_API_KEY')

@search_routes.get('/search')
Expand Down
75 changes: 75 additions & 0 deletions server/utils/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
""" This module contains the agent functions that interact with the external APIs. """

from langchain_google_community import GoogleSearchAPIWrapper, GoogleSearchResults, GooglePlacesTool, GooglePlacesAPIWrapper
from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.tools import YouTubeSearchTool

def get_google_search_results(query):
"""
Uses Google Custom Search to fetch search results for a given query.
Args:
query (str): The search query.
Returns:
list: A list of search results with titles and links.
"""

try:
google_search_wrapper = GoogleSearchAPIWrapper(k=3)
search_results = google_search_wrapper.run(query)
print("Search results obtained:", search_results)

# Ensure the results are JSON-serializable

return search_results

except Exception as e:
print(f"Failed to fetch Google search results: {e}")
return None

def get_youtube_search_results(query):
"""
Uses YouTube Search to fetch search results for a given query.
Args:
query (str): The search query.
Returns:
list: A list of search results with titles, descriptions, and video links.
"""
try:
youtube_search_tool = YouTubeSearchTool()
search_results = youtube_search_tool.run(query)
print("Search results obtained:", search_results)

# Ensure the results are JSON-serializable
return search_results

except Exception as e:
print(f"Failed to fetch YouTube search results: {e}")
return None


def get_bing_search_results(query):
"""
Uses Bing Search to fetch search results for a given query.
Args:
query (str): The search query.
Returns:
list: A list of search results with titles and links.
"""
try:
bing_search_wrapper = BingSearchAPIWrapper()
search_results = bing_search_wrapper.run(query)
print("Search results obtained:", search_results)

# Ensure the results are JSON-serializable
return search_results

except Exception as e:
print(f"Failed to fetch Bing search results: {e}")
return None

0 comments on commit 99e9c10

Please sign in to comment.