Skip to content

Commit

Permalink
Format Python code with psf/black push
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions authored and github-actions committed Feb 6, 2023
1 parent 1f40020 commit 42822e8
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
18 changes: 13 additions & 5 deletions cogs/search_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ async def paginate_embed(self, response_text):

return pages

async def search_command(self, ctx: discord.ApplicationContext, query, search_scope, nodes):
async def search_command(
self, ctx: discord.ApplicationContext, query, search_scope, nodes
):
"""Command handler for the translation command"""
user_api_key = None
if USER_INPUT_API_KEYS:
Expand All @@ -79,10 +81,15 @@ async def search_command(self, ctx: discord.ApplicationContext, query, search_sc
try:
response = await self.model.search(query, user_api_key, search_scope, nodes)
except ValueError:
await ctx.respond("The Google Search API returned an error. Check the console for more details.", ephemeral=True)
await ctx.respond(
"The Google Search API returned an error. Check the console for more details.",
ephemeral=True,
)
return
except Exception:
await ctx.respond("An error occurred. Check the console for more details.", ephemeral=True)
await ctx.respond(
"An error occurred. Check the console for more details.", ephemeral=True
)
traceback.print_exc()
return

Expand All @@ -95,7 +102,9 @@ async def search_command(self, ctx: discord.ApplicationContext, query, search_sc
urls = "\n".join(f"<{url}>" for url in urls)

query_response_message = f"**Query:**`\n\n{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}\n\n**Sources:**\n{urls}"
query_response_message = query_response_message.replace("<|endofstatement|>", "")
query_response_message = query_response_message.replace(
"<|endofstatement|>", ""
)

# If the response is too long, lets paginate using the discord pagination
# helper
Expand All @@ -107,4 +116,3 @@ async def search_command(self, ctx: discord.ApplicationContext, query, search_sc
)

await paginator.respond(ctx.interaction)

30 changes: 20 additions & 10 deletions models/index_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ async def index_web_pdf(self, url, embed_model) -> GPTSimpleVectorIndex:
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
return index


def index_gdoc(self, doc_id, embed_model) -> GPTSimpleVectorIndex:
document = GoogleDocsReader().load_data(doc_id)
index = GPTSimpleVectorIndex(document, embed_model=embed_model)
Expand Down Expand Up @@ -304,9 +303,6 @@ async def set_link_index(
await ctx.respond("Failed to get link", ephemeral=True)
return




# Check if the link contains youtube in it
if "youtube" in link:
index = await self.loop.run_in_executor(
Expand Down Expand Up @@ -415,11 +411,19 @@ async def compose_indexes(self, user_id, indexes, name, deep_compose):
for doc_id in [docmeta for docmeta in _index.docstore.docs.keys()]
if isinstance(_index.docstore.get_document(doc_id), Document)
]
llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003", max_tokens=-1))
llm_predictor = LLMPredictor(
llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
)
embedding_model = OpenAIEmbedding()

tree_index = await self.loop.run_in_executor(
None, partial(GPTTreeIndex, documents=documents, llm_predictor=llm_predictor, embed_model=embedding_model)
None,
partial(
GPTTreeIndex,
documents=documents,
llm_predictor=llm_predictor,
embed_model=embedding_model,
),
)

await self.usage_service.update_usage(llm_predictor.last_token_usage)
Expand Down Expand Up @@ -449,7 +453,12 @@ async def compose_indexes(self, user_id, indexes, name, deep_compose):
embedding_model = OpenAIEmbedding()

simple_index = await self.loop.run_in_executor(
None, partial(GPTSimpleVectorIndex, documents=documents, embed_model=embedding_model)
None,
partial(
GPTSimpleVectorIndex,
documents=documents,
embed_model=embedding_model,
),
)

await self.usage_service.update_usage(
Expand Down Expand Up @@ -533,8 +542,10 @@ async def query(
await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True
)
query_response_message=f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
query_response_message = query_response_message.replace("<|endofstatement|>", "")
query_response_message = f"**Query:**\n\n`{query.strip()}`\n\n**Query response:**\n\n{response.response.strip()}"
query_response_message = query_response_message.replace(
"<|endofstatement|>", ""
)
embed_pages = await self.paginate_embed(query_response_message)
paginator = pages.Paginator(
pages=embed_pages,
Expand Down Expand Up @@ -763,7 +774,6 @@ async def interaction_check(self, interaction: discord.Interaction) -> bool:
except discord.Forbidden:
pass


try:
await composing_message.delete()
except:
Expand Down
37 changes: 28 additions & 9 deletions models/search_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
Document,
PromptHelper,
LLMPredictor,
OpenAIEmbedding, SimpleDirectoryReader,
OpenAIEmbedding,
SimpleDirectoryReader,
)
from gpt_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR
from langchain import OpenAI
Expand Down Expand Up @@ -50,7 +51,6 @@ def index_webpage(self, url) -> list[Document]:
).load_data(urls=[url])
return documents


async def index_pdf(self, url) -> list[Document]:
# Download the PDF at the url and save it to a tempfile
async with aiohttp.ClientSession() as session:
Expand Down Expand Up @@ -79,11 +79,15 @@ async def get_links(self, query, search_scope=2):
if response.status == 200:
data = await response.json()
# Return a list of the top 2 links
return ([item["link"] for item in data["items"][:search_scope]], [
item["link"] for item in data["items"]
])
return (
[item["link"] for item in data["items"][:search_scope]],
[item["link"] for item in data["items"]],
)
else:
print("The Google Search API returned an error: " + str(response.status))
print(
"The Google Search API returned an error: "
+ str(response.status)
)
return ["An error occurred while searching.", None]

async def search(self, query, user_api_key, search_scope, nodes):
Expand Down Expand Up @@ -157,17 +161,32 @@ async def search(self, query, user_api_key, search_scope, nodes):

embedding_model = OpenAIEmbedding()

index = await self.loop.run_in_executor(None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model))
index = await self.loop.run_in_executor(
None, partial(GPTSimpleVectorIndex, documents, embed_model=embedding_model)
)

await self.usage_service.update_usage(
embedding_model.last_token_usage, embeddings=True
)

llm_predictor = LLMPredictor(llm=OpenAI(model_name="text-davinci-003", max_tokens=-1))
llm_predictor = LLMPredictor(
llm=OpenAI(model_name="text-davinci-003", max_tokens=-1)
)
# Now we can search the index for a query:
embedding_model.last_token_usage = 0

response = await self.loop.run_in_executor(None, partial(index.query, query, verbose=True, embed_model=embedding_model, llm_predictor=llm_predictor, similarity_top_k=nodes or DEFAULT_SEARCH_NODES, text_qa_template=self.qaprompt))
response = await self.loop.run_in_executor(
None,
partial(
index.query,
query,
verbose=True,
embed_model=embedding_model,
llm_predictor=llm_predictor,
similarity_top_k=nodes or DEFAULT_SEARCH_NODES,
text_qa_template=self.qaprompt,
),
)

await self.usage_service.update_usage(llm_predictor.last_token_usage)
await self.usage_service.update_usage(
Expand Down

0 comments on commit 42822e8

Please sign in to comment.