From 3f4f107ae30da61e86c5eac619f0f345ea9c6d7a Mon Sep 17 00:00:00 2001 From: Nico Arqueros Date: Tue, 29 Oct 2024 23:39:00 -0500 Subject: [PATCH] agent tools --- .../generic_chain/generic_inference_chain.rs | 85 +++++++++++-------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs index 112c83f39..fb3b1298d 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs @@ -181,43 +181,60 @@ impl GenericInferenceChain { ); if use_tools { - if let Some(tool_router) = &tool_router { - // TODO: enable back the default tools (must tools) - // // Get default tools - // if let Ok(default_tools) = tool_router.get_default_tools(&user_profile) { - // tools.extend(default_tools); - // } - - // Search in JS Tools - let results = tool_router - .vector_search_enabled_tools_with_network(&user_message.clone(), 5) - .await; - - match results { - Ok(results) => { - for result in results { - match tool_router.get_tool_by_name(&result.tool_router_key).await { - Ok(Some(tool)) => tools.push(tool), - Ok(None) => { - return Err(LLMProviderError::ToolNotFound(format!( - "Tool not found for key: {}", - result.tool_router_key - ))); - } - Err(e) => { - return Err(LLMProviderError::ToolRetrievalError(format!( - "Error retrieving tool: {:?}", - e - ))); - } + // If the llm_provider is an Agent, retrieve tools directly from the Agent struct + if let ProviderOrAgent::Agent(agent) = &llm_provider { + for tool_name in &agent.tools { + if let Some(tool_router) = &tool_router { + match tool_router.get_tool_by_name(tool_name).await { + Ok(Some(tool)) => tools.push(tool), + Ok(None) => { + return Err(LLMProviderError::ToolNotFound(format!( + "Tool not found for name: {}", + tool_name + ))); + } + Err(e) => { + return Err(LLMProviderError::ToolRetrievalError(format!( + "Error retrieving tool: {:?}", + e + ))); } } } - Err(e) => { - return Err(LLMProviderError::ToolSearchError(format!( - "Error during tool search: {:?}", - e - ))); + } + } else { + // If the llm_provider is not an Agent, perform a vector search for tools + if let Some(tool_router) = &tool_router { + let results = tool_router + .vector_search_enabled_tools_with_network(&user_message.clone(), 5) + .await; + + match results { + Ok(results) => { + for result in results { + match tool_router.get_tool_by_name(&result.tool_router_key).await { + Ok(Some(tool)) => tools.push(tool), + Ok(None) => { + return Err(LLMProviderError::ToolNotFound(format!( + "Tool not found for key: {}", + result.tool_router_key + ))); + } + Err(e) => { + return Err(LLMProviderError::ToolRetrievalError(format!( + "Error retrieving tool: {:?}", + e + ))); + } + } + } + } + Err(e) => { + return Err(LLMProviderError::ToolSearchError(format!( + "Error during tool search: {:?}", + e + ))); + } } } }