Skip to content

Commit

Permalink
agent tools
Browse files Browse the repository at this point in the history
  • Loading branch information
nicarq committed Oct 30, 2024
1 parent 7701473 commit 3f4f107
Showing 1 changed file with 51 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,43 +181,60 @@ impl GenericInferenceChain {
);

if use_tools {
if let Some(tool_router) = &tool_router {
// TODO: enable back the default tools (must tools)
// // Get default tools
// if let Ok(default_tools) = tool_router.get_default_tools(&user_profile) {
// tools.extend(default_tools);
// }

// Search in JS Tools
let results = tool_router
.vector_search_enabled_tools_with_network(&user_message.clone(), 5)
.await;

match results {
Ok(results) => {
for result in results {
match tool_router.get_tool_by_name(&result.tool_router_key).await {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => {
return Err(LLMProviderError::ToolNotFound(format!(
"Tool not found for key: {}",
result.tool_router_key
)));
}
Err(e) => {
return Err(LLMProviderError::ToolRetrievalError(format!(
"Error retrieving tool: {:?}",
e
)));
}
// If the llm_provider is an Agent, retrieve tools directly from the Agent struct
if let ProviderOrAgent::Agent(agent) = &llm_provider {
for tool_name in &agent.tools {
if let Some(tool_router) = &tool_router {
match tool_router.get_tool_by_name(tool_name).await {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => {
return Err(LLMProviderError::ToolNotFound(format!(
"Tool not found for name: {}",
tool_name
)));
}
Err(e) => {
return Err(LLMProviderError::ToolRetrievalError(format!(
"Error retrieving tool: {:?}",
e
)));
}
}
}
Err(e) => {
return Err(LLMProviderError::ToolSearchError(format!(
"Error during tool search: {:?}",
e
)));
}
} else {
// If the llm_provider is not an Agent, perform a vector search for tools
if let Some(tool_router) = &tool_router {
let results = tool_router
.vector_search_enabled_tools_with_network(&user_message.clone(), 5)
.await;

match results {
Ok(results) => {
for result in results {
match tool_router.get_tool_by_name(&result.tool_router_key).await {
Ok(Some(tool)) => tools.push(tool),
Ok(None) => {
return Err(LLMProviderError::ToolNotFound(format!(
"Tool not found for key: {}",
result.tool_router_key
)));
}
Err(e) => {
return Err(LLMProviderError::ToolRetrievalError(format!(
"Error retrieving tool: {:?}",
e
)));
}
}
}
}
Err(e) => {
return Err(LLMProviderError::ToolSearchError(format!(
"Error during tool search: {:?}",
e
)));
}
}
}
}
Expand Down

0 comments on commit 3f4f107

Please sign in to comment.