diff --git a/docs/openapi/general.yaml b/docs/openapi/general.yaml index c10c96a1c..0f8582878 100644 --- a/docs/openapi/general.yaml +++ b/docs/openapi/general.yaml @@ -763,16 +763,8 @@ components: required: - id - full_identity_name - - perform_locally - model - - toolkit_permissions - - storage_bucket_permissions - - allowed_message_senders properties: - allowed_message_senders: - type: array - items: - type: string api_key: type: string nullable: true @@ -785,16 +777,6 @@ components: type: string model: $ref: '#/components/schemas/LLMProviderInterface' - perform_locally: - type: boolean - storage_bucket_permissions: - type: array - items: - type: string - toolkit_permissions: - type: array - items: - type: string ShinkaiBackend: type: object required: diff --git a/docs/openapi/jobs.yaml b/docs/openapi/jobs.yaml index 05072c9bb..c68b5b8e1 100644 --- a/docs/openapi/jobs.yaml +++ b/docs/openapi/jobs.yaml @@ -1412,16 +1412,8 @@ components: required: - id - full_identity_name - - perform_locally - model - - toolkit_permissions - - storage_bucket_permissions - - allowed_message_senders properties: - allowed_message_senders: - type: array - items: - type: string api_key: type: string nullable: true @@ -1434,16 +1426,6 @@ components: type: string model: $ref: '#/components/schemas/LLMProviderInterface' - perform_locally: - type: boolean - storage_bucket_permissions: - type: array - items: - type: string - toolkit_permissions: - type: array - items: - type: string SheetJobAction: type: object required: diff --git a/shinkai-bin/shinkai-node/src/llm_provider/error.rs b/shinkai-bin/shinkai-node/src/llm_provider/error.rs index b9fa0d42a..00cffa35c 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/error.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/error.rs @@ -87,6 +87,7 @@ pub enum LLMProviderError { ToolNotFound(String), ToolRetrievalError(String), ToolSearchError(String), + AgentNotFound(String), } impl fmt::Display for LLMProviderError { @@ -176,6 +177,7 @@ impl fmt::Display for LLMProviderError { LLMProviderError::ToolNotFound(s) => write!(f, "Tool not found: {}", s), LLMProviderError::ToolRetrievalError(s) => write!(f, "Tool retrieval error: {}", s), LLMProviderError::ToolSearchError(s) => write!(f, "Tool search error: {}", s), + LLMProviderError::AgentNotFound(s) => write!(f, "Agent not found: {}", s), } } } @@ -255,6 +257,7 @@ impl LLMProviderError { LLMProviderError::ToolNotFound(_) => "ToolNotFound", LLMProviderError::ToolRetrievalError(_) => "ToolRetrievalError", LLMProviderError::ToolSearchError(_) => "ToolSearchError", + LLMProviderError::AgentNotFound(_) => "AgentNotFound", }; let error_message = format!("{}", self); diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/dsl_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/dsl_inference_chain.rs index 44a5531c6..8dd526d0a 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/dsl_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/dsl_inference_chain.rs @@ -15,7 +15,7 @@ use regex::Regex; use shinkai_baml::baml_builder::{BamlConfig, ClientConfig, GeneratorConfig}; use shinkai_dsl::dsl_schemas::Workflow; use shinkai_message_primitives::{ - schemas::{inbox_name::InboxName, job::JobLike}, + schemas::{inbox_name::InboxName, job::JobLike, llm_providers::common_agent_llm_provider::ProviderOrAgent}, shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}, }; use shinkai_sqlite::logger::{WorkflowLogEntry, WorkflowLogEntryStatus}; @@ -357,6 +357,7 @@ impl AsyncFunction for InferenceFunction { }, None, // this is the config self.context.llm_stopper().clone(), + self.context.db().clone(), ) .await .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?; @@ -460,6 +461,7 @@ impl AsyncFunction for OpinionatedInferenceFunction { }, None, // this is the config self.context.llm_stopper().clone(), + self.context.db().clone(), ) .await .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?; @@ -507,11 +509,26 @@ impl AsyncFunction for BamlInference { // TODO: do we need the job for something? // let full_job = self.context.full_job(); - let llm_provider = self.context.agent(); + let llm_or_agent_provider = self.context.agent(); let generator_config = GeneratorConfig::default(); - // TODO: add support for other providers + let llm_provider = match llm_or_agent_provider { + ProviderOrAgent::LLMProvider(provider) => provider, + ProviderOrAgent::Agent(agent) => { + let llm_provider_id = agent.llm_provider_id.clone(); + let profile = agent.full_identity_name.clone(); + let provider = self.context + .db() + .get_llm_provider(&llm_provider_id, &profile) + .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?; + &provider + .as_ref() + .ok_or_else(|| WorkflowError::ExecutionError("LLM provider not found".to_string()))? + .clone() + } + }; + let base_url = llm_provider.external_url.clone().unwrap_or_default(); let base_url = if base_url == "http://localhost:11434" || base_url == "http://localhost:11435" { format!("{}/v1", base_url) @@ -619,7 +636,15 @@ impl AsyncFunction for MultiInferenceFunction { let mut responses = Vec::new(); let agent = self.context.agent(); - let max_tokens = ModelCapabilitiesManager::get_max_input_tokens(&agent.model); + let max_tokens = ModelCapabilitiesManager::get_max_input_tokens_for_provider_or_agent( + agent.clone(), + self.context.db().clone(), + ); + + let max_tokens = match max_tokens { + Some(tokens) => tokens, + None => return Err(WorkflowError::ExecutionError("Max tokens not found".to_string())), + }; for text in split_texts.iter() { let inference_args = vec![ diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/generic_functions.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/generic_functions.rs index eac511263..6abe584b6 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/generic_functions.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/generic_functions.rs @@ -399,6 +399,7 @@ pub fn search_embeddings_in_job_scope( #[cfg(test)] mod tests { use shinkai_db::db::ShinkaiDB; + use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{ LLMProviderInterface, OpenAI, SerializedLLMProvider, }; @@ -696,13 +697,9 @@ mod tests { let agent = SerializedLLMProvider { id: "test_agent_id".to_string(), full_identity_name: agent_name, - perform_locally: false, external_url: Some("https://api.openai.com".to_string()), api_key: Some("mockapikey".to_string()), model: LLMProviderInterface::OpenAI(open_ai), - toolkit_permissions: vec![], - storage_bucket_permissions: vec![], - allowed_message_senders: vec![], }; let image_files = HashMap::new(); @@ -717,7 +714,7 @@ mod tests { }, None, image_files, - agent, + ProviderOrAgent::LLMProvider(agent), HashMap::new(), generator, ShinkaiName::default_testnet_localhost(), @@ -832,13 +829,9 @@ mod tests { let agent = SerializedLLMProvider { id: "test_agent_id_with_query".to_string(), full_identity_name: agent_name, - perform_locally: false, external_url: Some("https://api.openai.com".to_string()), api_key: Some("mockapikey".to_string()), model: LLMProviderInterface::OpenAI(open_ai), - toolkit_permissions: vec![], - storage_bucket_permissions: vec![], - allowed_message_senders: vec![], }; let image_files = HashMap::new(); @@ -853,7 +846,7 @@ mod tests { }, None, image_files, - agent, + ProviderOrAgent::LLMProvider(agent), HashMap::new(), generator, ShinkaiName::default_testnet_localhost(), @@ -983,13 +976,9 @@ mod tests { let agent = SerializedLLMProvider { id: "test_agent_id".to_string(), full_identity_name: agent_name, - perform_locally: false, external_url: Some("https://api.openai.com".to_string()), api_key: Some("mockapikey".to_string()), model: LLMProviderInterface::OpenAI(open_ai), - toolkit_permissions: vec![], - storage_bucket_permissions: vec![], - allowed_message_senders: vec![], }; let image_files = HashMap::new(); @@ -1004,7 +993,7 @@ mod tests { }, None, image_files, - agent, + ProviderOrAgent::LLMProvider(agent), HashMap::new(), generator, ShinkaiName::default_testnet_localhost(), diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/split_text_for_llm.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/split_text_for_llm.rs index 3c35823d3..1e209b8c5 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/split_text_for_llm.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/dsl_chain/split_text_for_llm.rs @@ -15,7 +15,12 @@ pub fn split_text_for_llm( .clone(); let agent = context.agent(); - let max_tokens = ModelCapabilitiesManager::get_max_input_tokens(&agent.model); + let max_tokens = ModelCapabilitiesManager::get_max_input_tokens_for_provider_or_agent(agent.clone(), context.db().clone()); + + let max_tokens = match max_tokens { + Some(tokens) => tokens, + None => return Err(WorkflowError::ExecutionError("Max tokens not found".to_string())), + }; let mut result = Vec::new(); let current_text = input1.clone(); diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs index 00f80694e..8a243f75f 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs @@ -6,7 +6,6 @@ use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator use crate::llm_provider::execution::user_message_parser::ParsedUserMessage; use crate::llm_provider::job_manager::JobManager; use crate::llm_provider::llm_stopper::LLMStopper; -use crate::llm_provider::providers::shared::openai_api::FunctionCall; use crate::managers::model_capabilities_manager::ModelCapabilitiesManager; use crate::managers::sheet_manager::SheetManager; use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter}; @@ -19,9 +18,7 @@ use shinkai_db::schemas::ws_types::{ }; use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::job::{Job, JobLike}; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{ - LLMProviderInterface, SerializedLLMProvider, -}; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; @@ -111,7 +108,7 @@ impl GenericInferenceChain { user_message: String, message_hash_id: Option, image_files: HashMap, - llm_provider: SerializedLLMProvider, + llm_provider: ProviderOrAgent, execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, @@ -177,50 +174,82 @@ impl GenericInferenceChain { ); let mut tools = vec![]; let stream = job_config.as_ref().and_then(|config| config.stream); - let use_tools = ModelCapabilitiesManager::has_tool_capabilities(&llm_provider.model, stream); + let use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent( + llm_provider.clone(), + db.clone(), + stream, + ); if use_tools { - if let Some(tool_router) = &tool_router { - // TODO: enable back the default tools (must tools) - // // Get default tools - // if let Ok(default_tools) = tool_router.get_default_tools(&user_profile) { - // tools.extend(default_tools); - // } - - // Search in JS Tools - let results = tool_router - .vector_search_enabled_tools_with_network(&user_message.clone(), 5) - .await; - - match results { - Ok(results) => { - for result in results { - match tool_router.get_tool_by_name(&result.tool_router_key).await { - Ok(Some(tool)) => tools.push(tool), - Ok(None) => { - return Err(LLMProviderError::ToolNotFound( - format!("Tool not found for key: {}", result.tool_router_key), - )); - } - Err(e) => { - return Err(LLMProviderError::ToolRetrievalError( - format!("Error retrieving tool: {:?}", e), - )); - } + // If the llm_provider is an Agent, retrieve tools directly from the Agent struct + if let ProviderOrAgent::Agent(agent) = &llm_provider { + for tool_name in &agent.tools { + if let Some(tool_router) = &tool_router { + match tool_router.get_tool_by_name(tool_name).await { + Ok(Some(tool)) => tools.push(tool), + Ok(None) => { + return Err(LLMProviderError::ToolNotFound(format!( + "Tool not found for name: {}", + tool_name + ))); + } + Err(e) => { + return Err(LLMProviderError::ToolRetrievalError(format!( + "Error retrieving tool: {:?}", + e + ))); } } } - Err(e) => { - return Err(LLMProviderError::ToolSearchError( - format!("Error during tool search: {:?}", e), - )); + } + } else { + // If the llm_provider is not an Agent, perform a vector search for tools + if let Some(tool_router) = &tool_router { + let results = tool_router + .vector_search_enabled_tools_with_network(&user_message.clone(), 5) + .await; + + match results { + Ok(results) => { + for result in results { + match tool_router.get_tool_by_name(&result.tool_router_key).await { + Ok(Some(tool)) => tools.push(tool), + Ok(None) => { + return Err(LLMProviderError::ToolNotFound(format!( + "Tool not found for key: {}", + result.tool_router_key + ))); + } + Err(e) => { + return Err(LLMProviderError::ToolRetrievalError(format!( + "Error retrieving tool: {:?}", + e + ))); + } + } + } + } + Err(e) => { + return Err(LLMProviderError::ToolSearchError(format!( + "Error during tool search: {:?}", + e + ))); + } } } } } // 3) Generate Prompt - let custom_prompt = job_config.and_then(|config| config.custom_prompt.clone()); + // First, attempt to use the custom_prompt from the job's config. + // If it doesn't exist, fall back to the agent's custom_prompt if the llm_provider is an Agent. + let custom_prompt = job_config.and_then(|config| config.custom_prompt.clone()).or_else(|| { + if let ProviderOrAgent::Agent(agent) = &llm_provider { + agent.config.as_ref().and_then(|config| config.custom_prompt.clone()) + } else { + None + } + }); let mut filled_prompt = JobPromptGenerator::generic_inference_prompt( custom_prompt, @@ -257,6 +286,7 @@ impl GenericInferenceChain { ws_manager_trait.clone(), job_config.cloned(), llm_stopper.clone(), + db.clone(), ) .await; @@ -326,7 +356,13 @@ impl GenericInferenceChain { tool_calls_history.push(function_call_with_router_key); // Trigger WS update after receiving function_response - Self::trigger_ws_update(&ws_manager_trait, &Some(full_job.job_id.clone()), &function_response, shinkai_tool.tool_router_key()).await; + Self::trigger_ws_update( + &ws_manager_trait, + &Some(full_job.job_id.clone()), + &function_response, + shinkai_tool.tool_router_key(), + ) + .await; // 7) Call LLM again with the response (for formatting) filled_prompt = JobPromptGenerator::generic_inference_prompt( diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs index fab01c3e0..21951cf35 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs @@ -3,6 +3,7 @@ use super::inference_chain_trait::{InferenceChain, InferenceChainContext, Infere use super::sheet_ui_chain::sheet_ui_inference_chain::SheetUIInferenceChain; use shinkai_db::db::ShinkaiDB; use shinkai_message_primitives::schemas::job::Job; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_sqlite::SqliteLogger; use shinkai_vector_fs::vector_fs::vector_fs::VectorFS; use crate::llm_provider::error::LLMProviderError; @@ -15,7 +16,6 @@ use crate::managers::tool_router::ToolRouter; use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager; use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager; use shinkai_db::schemas::ws_types::WSUpdateHandler; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{AssociatedUI, JobMessage}; use shinkai_vector_resources::embedding_generator::RemoteEmbeddingGenerator; @@ -30,7 +30,7 @@ impl JobManager { pub async fn inference_chain_router( db: Arc, vector_fs: Arc, - llm_provider_found: Option, + llm_provider_found: Option, full_job: Job, job_message: JobMessage, message_hash_id: Option, @@ -48,7 +48,19 @@ impl JobManager { ) -> Result { // Initializations let llm_provider = llm_provider_found.ok_or(LLMProviderError::LLMProviderNotFound)?; - let max_tokens_in_prompt = ModelCapabilitiesManager::get_max_input_tokens(&llm_provider.model); + let model = { + if let ProviderOrAgent::LLMProvider(llm_provider) = llm_provider.clone() { + &llm_provider.model.clone() + } else { + // If it's an agent, we need to get the LLM provider from the agent + let llm_id = llm_provider.get_llm_provider_id(); + let llm_provider = db + .get_llm_provider(llm_id, &user_profile)? + .ok_or(LLMProviderError::LLMProviderNotFound)?; + &llm_provider.model.clone() + } + }; + let max_tokens_in_prompt = ModelCapabilitiesManager::get_max_input_tokens(&model); let parsed_user_message = ParsedUserMessage::new(job_message.content.to_string()); // Create the inference chain context diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs index 42a719ca2..b29a48beb 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs @@ -11,7 +11,7 @@ use serde_json::Value as JsonValue; use shinkai_db::db::ShinkaiDB; use shinkai_db::schemas::ws_types::WSUpdateHandler; use shinkai_message_primitives::schemas::job::Job; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::FunctionCallMetadata; use shinkai_sqlite::SqliteLogger; @@ -63,7 +63,7 @@ pub trait InferenceChainContextTrait: Send + Sync { fn user_message(&self) -> &ParsedUserMessage; fn message_hash_id(&self) -> Option; fn image_files(&self) -> &HashMap; - fn agent(&self) -> &SerializedLLMProvider; + fn agent(&self) -> &ProviderOrAgent; fn execution_context(&self) -> &HashMap; fn generator(&self) -> &RemoteEmbeddingGenerator; fn user_profile(&self) -> &ShinkaiName; @@ -129,7 +129,7 @@ impl InferenceChainContextTrait for InferenceChainContext { &self.image_files } - fn agent(&self) -> &SerializedLLMProvider { + fn agent(&self) -> &ProviderOrAgent { &self.llm_provider } @@ -204,7 +204,7 @@ pub struct InferenceChainContext { pub user_message: ParsedUserMessage, pub message_hash_id: Option, pub image_files: HashMap, - pub llm_provider: SerializedLLMProvider, + pub llm_provider: ProviderOrAgent, /// Job's execution context, used to store potentially relevant data across job steps. pub execution_context: HashMap, pub generator: RemoteEmbeddingGenerator, @@ -231,7 +231,7 @@ impl InferenceChainContext { user_message: ParsedUserMessage, message_hash_id: Option, image_files: HashMap, - agent: SerializedLLMProvider, + llm_provider: ProviderOrAgent, execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, @@ -252,7 +252,7 @@ impl InferenceChainContext { user_message, message_hash_id, image_files, - llm_provider: agent, + llm_provider, execution_context, generator, user_profile, @@ -443,7 +443,7 @@ impl InferenceChainContextTrait for Box { (**self).image_files() } - fn agent(&self) -> &SerializedLLMProvider { + fn agent(&self) -> &ProviderOrAgent { (**self).agent() } @@ -626,7 +626,7 @@ impl InferenceChainContextTrait for MockInferenceChainContext { &self.image_files } - fn agent(&self) -> &SerializedLLMProvider { + fn agent(&self) -> &ProviderOrAgent { unimplemented!() } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs index 3f19b0ed8..9368602ec 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs @@ -7,6 +7,7 @@ use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator use crate::llm_provider::execution::user_message_parser::ParsedUserMessage; use crate::llm_provider::job_manager::JobManager; use crate::llm_provider::llm_stopper::LLMStopper; +use crate::managers::model_capabilities_manager::ModelCapabilitiesManager; use crate::managers::sheet_manager::SheetManager; use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter}; use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager; @@ -16,9 +17,7 @@ use shinkai_db::db::ShinkaiDB; use shinkai_db::schemas::ws_types::WSUpdateHandler; use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::job::{Job, JobLike}; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{ - LLMProviderInterface, SerializedLLMProvider, -}; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_sqlite::SqliteLogger; @@ -111,7 +110,7 @@ impl SheetUIInferenceChain { user_message: String, message_hash_id: Option, image_files: HashMap, - llm_provider: SerializedLLMProvider, + llm_provider: ProviderOrAgent, execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, @@ -214,33 +213,13 @@ impl SheetUIInferenceChain { // 2) Vector search for tooling / workflows if the workflow / tooling scope isn't empty let job_config = full_job.config(); let mut tools = vec![]; - let use_tools = match &llm_provider.model { - LLMProviderInterface::OpenAI(_) => true, - LLMProviderInterface::Ollama(model_type) => { - let is_supported_model = model_type.model_type.starts_with("llama3.1") - || model_type.model_type.starts_with("llama3.2") - || model_type.model_type.starts_with("llama-3.1") - || model_type.model_type.starts_with("llama-3.2") - || model_type.model_type.starts_with("mistral-nemo") - || model_type.model_type.starts_with("mistral-small") - || model_type.model_type.starts_with("mistral-large"); - is_supported_model - && job_config - .as_ref() - .map_or(true, |config| config.stream.unwrap_or(true) == false) - } - LLMProviderInterface::Groq(model_type) => { - let is_supported_model = model_type.model_type.starts_with("llama-3.2") - || model_type.model_type.starts_with("llama3.2") - || model_type.model_type.starts_with("llama-3.1") - || model_type.model_type.starts_with("llama3.1"); - is_supported_model - && job_config - .as_ref() - .map_or(true, |config| config.stream.unwrap_or(true) == false) - } - _ => false, - }; + let stream = job_config.as_ref().and_then(|config| config.stream); + let use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent( + llm_provider.clone(), + db.clone(), + stream, + ); + if use_tools { tools.extend(SheetRustFunctions::sheet_rust_fn()); @@ -266,11 +245,11 @@ impl SheetUIInferenceChain { // 3) Generate Prompt let job_config = full_job.config(); - + let csv_result = { let sheet_manager_clone = sheet_manager.clone().unwrap(); let sheet_id_clone = sheet_id.clone(); - + // Export the current CSV data let csv_result = SheetRustFunctions::get_table(sheet_manager_clone, sheet_id_clone, HashMap::new()).await; @@ -280,7 +259,7 @@ impl SheetUIInferenceChain { }; csv_data - }; + }; // Extend the user message to include the CSV data if available let extended_user_message = if csv_result.is_empty() { @@ -327,6 +306,7 @@ impl SheetUIInferenceChain { ws_manager_trait.clone(), job_config.cloned(), llm_stopper.clone(), + db.clone(), ) .await; @@ -447,4 +427,4 @@ impl SheetUIInferenceChain { iteration_count += 1; } } -} \ No newline at end of file +} diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs index 0f926e734..748f1d9a3 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs @@ -16,6 +16,7 @@ use shinkai_dsl::dsl_schemas::Workflow; use shinkai_dsl::parser::parse_workflow; use shinkai_job_queue_manager::job_queue_manager::{JobForProcessing, JobQueueManager}; use shinkai_message_primitives::schemas::job::{Job, JobLike}; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; use shinkai_message_primitives::schemas::sheet::WorkflowSheetJobData; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{CallbackAction, MessageMetadata, WSTopic}; @@ -258,7 +259,7 @@ impl JobManager { job_message: JobMessage, message_hash_id: Option, full_job: Job, - llm_provider_found: Option, + llm_provider_found: Option, user_profile: ShinkaiName, generator: RemoteEmbeddingGenerator, ws_manager: Option>>, @@ -394,7 +395,7 @@ impl JobManager { vector_fs: Arc, job_message: &JobMessage, message_hash_id: Option, - llm_provider_found: Option, + llm_provider_found: Option, full_job: Job, identity_secret_key: SigningKey, generator: RemoteEmbeddingGenerator, @@ -548,7 +549,7 @@ impl JobManager { job_message: &JobMessage, message_hash_id: Option, message_content: String, - llm_provider_found: Option, + llm_provider_found: Option, full_job: Job, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, @@ -562,7 +563,19 @@ impl JobManager { llm_stopper: Arc, ) -> Result { let llm_provider = llm_provider_found.ok_or(LLMProviderError::LLMProviderNotFound)?; - let max_tokens_in_prompt = ModelCapabilitiesManager::get_max_input_tokens(&llm_provider.model); + let model = { + if let ProviderOrAgent::LLMProvider(llm_provider) = llm_provider.clone() { + &llm_provider.model.clone() + } else { + // If it's an agent, we need to get the LLM provider from the agent + let llm_id = llm_provider.get_llm_provider_id(); + let llm_provider = db + .get_llm_provider(llm_id, &user_profile)? + .ok_or(LLMProviderError::LLMProviderNotFound)?; + &llm_provider.model.clone() + } + }; + let max_tokens_in_prompt = ModelCapabilitiesManager::get_max_input_tokens(&model); let parsed_user_message = ParsedUserMessage::new(message_content); let full_execution_context = full_job.execution_context.clone(); let empty_files = HashMap::new(); @@ -574,7 +587,7 @@ impl JobManager { parsed_user_message, message_hash_id, empty_files, - llm_provider, + llm_provider.clone(), full_execution_context, generator, user_profile.clone(), @@ -648,7 +661,7 @@ impl JobManager { vector_fs: Arc, job_message: &JobMessage, message_hash_id: Option, - llm_provider_found: Option, + llm_provider_found: Option, full_job: Job, user_profile: ShinkaiName, generator: RemoteEmbeddingGenerator, @@ -823,7 +836,7 @@ impl JobManager { db: Arc, vector_fs: Arc, files: Vec<(String, Vec)>, - agent_found: Option, + agent_found: Option, full_job: &mut Job, profile: ShinkaiName, save_to_vector_fs_folder: Option, @@ -982,7 +995,7 @@ impl JobManager { db: Arc, vector_fs: Arc, job_message: &JobMessage, - agent_found: Option, + agent_found: Option, full_job: &mut Job, profile: ShinkaiName, save_to_vector_fs_folder: Option, @@ -1033,7 +1046,7 @@ impl JobManager { vector_fs: Arc, files_inbox: String, file_names: Vec, - agent_found: Option, + agent_found: Option, full_job: &mut Job, profile: ShinkaiName, save_to_vector_fs_folder: Option, @@ -1131,9 +1144,9 @@ impl JobManager { /// Else, the files will be returned as LocalScopeEntries and thus held inside. #[allow(clippy::too_many_arguments)] pub async fn process_files_inbox( - _db: Arc, + db: Arc, _vector_fs: Arc, - agent: Option, + agent: Option, files: Vec<(String, Vec)>, _profile: ShinkaiName, save_to_vector_fs_folder: Option, @@ -1168,7 +1181,7 @@ impl JobManager { dist_files.push((file.0, file.1, distribution_info)); } - let processed_vrkais = ParsingHelper::process_files_into_vrkai(dist_files, &generator, agent.clone()).await?; + let processed_vrkais = ParsingHelper::process_files_into_vrkai(dist_files, &generator, agent.clone(), db.clone()).await?; // Save the vrkai into scope (and potentially VectorFS) for (filename, vrkai) in processed_vrkais { diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_helpers.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_helpers.rs index 0b66cac05..0e2bd1b4b 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_helpers.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_helpers.rs @@ -10,28 +10,29 @@ use crate::llm_provider::llm_provider::LLMProvider; use crate::llm_provider::llm_stopper::LLMStopper; use shinkai_db::schemas::ws_types::WSUpdateHandler; use shinkai_message_primitives::schemas::inbox_name::InboxName; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use tokio::sync::Mutex; use std::result::Result::Ok; use std::sync::Arc; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; impl JobManager { /// Inferences the Agent's LLM with the given prompt. pub async fn inference_with_llm_provider( - llm_provider: SerializedLLMProvider, + llm_provider: ProviderOrAgent, filled_prompt: Prompt, inbox_name: Option, ws_manager_trait: Option>>, config: Option, llm_stopper: Arc, + db: Arc, ) -> Result { let llm_provider_cloned = llm_provider.clone(); let prompt_cloned = filled_prompt.clone(); let task_response = tokio::spawn(async move { - let llm_provider = LLMProvider::from_serialized_llm_provider(llm_provider_cloned); + let llm_provider = LLMProvider::from_provider_or_agent(llm_provider_cloned, db.clone())?; llm_provider.inference(prompt_cloned, inbox_name, ws_manager_trait, config, llm_stopper).await }) .await; @@ -51,29 +52,44 @@ impl JobManager { pub async fn fetch_relevant_job_data( job_id: &str, db: Arc, - ) -> Result<(Job, Option, String, Option), LLMProviderError> { + ) -> Result<(Job, Option, String, Option), LLMProviderError> { // Fetch the job let full_job = { db.get_job(job_id)? }; // Acquire Agent - let llm_provider_id = full_job.parent_llm_provider_id.clone(); - let mut llm_provider_found = None; + let agent_or_llm_provider_id = full_job.parent_agent_or_llm_provider_id.clone(); + let mut agent_or_llm_provider_found = None; let mut profile_name = String::new(); let mut user_profile: Option = None; - let llm_providers = JobManager::get_all_llm_providers(db).await.unwrap_or(vec![]); - for llm_provider in llm_providers { - if llm_provider.id == llm_provider_id { - llm_provider_found = Some(llm_provider.clone()); - profile_name.clone_from(&llm_provider.full_identity_name.full_name); - user_profile = Some(llm_provider.full_identity_name.extract_profile().unwrap()); + let agents_and_llm_providers = JobManager::get_all_agents_and_llm_providers(db).await.unwrap_or(vec![]); + for agent_or_llm_provider in agents_and_llm_providers { + if agent_or_llm_provider.get_id() == &agent_or_llm_provider_id { + agent_or_llm_provider_found = Some(agent_or_llm_provider.clone()); + profile_name.clone_from(&agent_or_llm_provider.get_full_identity_name().full_name); + user_profile = Some(agent_or_llm_provider.get_full_identity_name().extract_profile().unwrap()); break; } } - Ok((full_job, llm_provider_found, profile_name, user_profile)) + Ok((full_job, agent_or_llm_provider_found, profile_name, user_profile)) } - pub async fn get_all_llm_providers(db: Arc) -> Result, ShinkaiDBError> { - db.get_all_llm_providers() + pub async fn get_all_agents_and_llm_providers( + db: Arc, + ) -> Result, ShinkaiDBError> { + let llm_providers = db.get_all_llm_providers()?; + let agents = db.get_all_agents()?; + + let mut providers_and_agents = Vec::new(); + + for llm_provider in llm_providers { + providers_and_agents.push(ProviderOrAgent::LLMProvider(llm_provider)); + } + + for agent in agents { + providers_and_agents.push(ProviderOrAgent::Agent(agent)); + } + + Ok(providers_and_agents) } } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs b/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs index 8fb5029e1..80751b9a6 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs @@ -57,7 +57,6 @@ pub struct JobManager { pub jobs: Arc>>>, pub db: Weak, pub identity_manager: Arc>, - pub llm_providers: Vec>>, pub identity_secret_key: SigningKey, pub job_queue_manager: Arc>>, pub node_profile_name: ShinkaiName, @@ -94,17 +93,6 @@ impl JobManager { } } - // Get all serialized_llm_providers and convert them to LLM Providers - let mut llm_providers: Vec>> = Vec::new(); - { - let identity_manager = identity_manager.lock().await; - let serialized_llm_providers = identity_manager.get_all_llm_providers().await.unwrap(); - for serialized_agent in serialized_llm_providers { - let llm_provider = LLMProvider::from_serialized_llm_provider(serialized_agent); - llm_providers.push(Arc::new(Mutex::new(llm_provider))); - } - } - let db_prefix = "job_manager_abcdeprefix_"; let job_queue = JobQueueManager::::new( db.clone(), @@ -179,7 +167,6 @@ impl JobManager { node_profile_name, jobs: jobs_map, identity_manager, - llm_providers, job_queue_manager: job_queue_manager.clone(), job_processing_task: Some(job_queue_handler), ws_manager, @@ -456,8 +443,8 @@ impl JobManager { pub async fn process_job_creation( &mut self, job_creation: JobCreationInfo, - profile: &ShinkaiName, - llm_provider_id: &String, + _profile: &ShinkaiName, + llm_or_agent_provider_id: &String, ) -> Result { let job_id = format!("jobid_{}", uuid::Uuid::new_v4()); { @@ -465,7 +452,7 @@ impl JobManager { let is_hidden = job_creation.is_hidden.unwrap_or(false); match db_arc.create_new_job( job_id.clone(), - llm_provider_id.clone(), + llm_or_agent_provider_id.clone(), job_creation.scope, is_hidden, job_creation.associated_ui, @@ -479,35 +466,7 @@ impl JobManager { Ok(job) => { std::mem::drop(db_arc); // require to avoid deadlock self.jobs.lock().await.insert(job_id.clone(), Box::new(job)); - let mut llm_provider_found = None; - for agent in &self.llm_providers { - let locked_agent = agent.lock().await; - if &locked_agent.id == llm_provider_id { - llm_provider_found = Some(agent.clone()); - break; - } - } - - if llm_provider_found.is_none() { - let identity_manager = self.identity_manager.lock().await; - if let Some(serialized_agent) = identity_manager - .search_local_llm_provider(llm_provider_id, profile) - .await - { - let agent = LLMProvider::from_serialized_llm_provider(serialized_agent); - llm_provider_found = Some(Arc::new(Mutex::new(agent))); - if let Some(agent) = llm_provider_found.clone() { - self.llm_providers.push(agent); - } - } - } - - let job_id_to_return = match llm_provider_found { - Some(_) => Ok(job_id.clone()), - None => Err(anyhow::Error::new(LLMProviderError::LLMProviderNotFound)), - }; - - job_id_to_return.map_err(|_| LLMProviderError::LLMProviderNotFound) + Ok(job_id.clone()) } Err(err) => Err(LLMProviderError::ShinkaiDB(err)), } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/llm_provider.rs b/shinkai-bin/shinkai-node/src/llm_provider/llm_provider.rs index 0e459d543..f0db06e99 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/llm_provider.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/llm_provider.rs @@ -6,9 +6,12 @@ use super::llm_stopper::LLMStopper; use super::providers::LLMService; use reqwest::Client; use serde_json::{Map, Value as JsonValue}; +use shinkai_db::db::ShinkaiDB; use shinkai_db::schemas::ws_types::WSUpdateHandler; use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::job_config::JobConfig; +use shinkai_message_primitives::schemas::llm_providers::agent::Agent; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::schemas::prompts::Prompt; use shinkai_message_primitives::schemas::{ llm_providers::serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}, @@ -21,13 +24,10 @@ pub struct LLMProvider { pub id: String, pub full_identity_name: ShinkaiName, pub client: Client, - pub perform_locally: bool, // Todo: Remove as not used anymore pub external_url: Option, // external API URL pub api_key: Option, pub model: LLMProviderInterface, - pub toolkit_permissions: Vec, // Todo: remove as not used - pub storage_bucket_permissions: Vec, // Todo: remove as not used - pub allowed_message_senders: Vec, // list of sub-identities allowed to message the llm provider + pub agent: Option, } impl LLMProvider { @@ -35,13 +35,10 @@ impl LLMProvider { pub fn new( id: String, full_identity_name: ShinkaiName, - perform_locally: bool, external_url: Option, api_key: Option, model: LLMProviderInterface, - toolkit_permissions: Vec, - storage_bucket_permissions: Vec, - allowed_message_senders: Vec, + agent: Option, ) -> Self { let client = Client::builder() .timeout(std::time::Duration::from_secs(300)) // 5 min TTFT @@ -51,13 +48,10 @@ impl LLMProvider { id, full_identity_name, client, - perform_locally, external_url, api_key, model, - toolkit_permissions, - storage_bucket_permissions, - allowed_message_senders, + agent, } } @@ -88,6 +82,20 @@ impl LLMProvider { config: Option, llm_stopper: Arc, ) -> Result { + // Merge config with agent's config, preferring the provided config + let merged_config = if let Some(agent) = &self.agent { + if let Some(agent_config) = &agent.config { + // Prefer `config` over `agent_config` + Some(config.unwrap_or_else(JobConfig::empty).merge(agent_config)) + } else { + // Use provided config or create an empty one if none is provided + config.or_else(|| Some(JobConfig::empty())) + } + } else { + // Use provided config if no agent is present + config + }; + let response = match &self.model { LLMProviderInterface::OpenAI(openai) => { openai @@ -99,7 +107,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -114,7 +122,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -129,7 +137,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -143,7 +151,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -158,7 +166,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -172,7 +180,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -187,7 +195,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -202,7 +210,7 @@ impl LLMProvider { self.model.clone(), inbox_name, ws_manager_trait, - config, + merged_config, llm_stopper, ) .await @@ -220,13 +228,32 @@ impl LLMProvider { Self::new( serialized_llm_provider.id, serialized_llm_provider.full_identity_name, - serialized_llm_provider.perform_locally, serialized_llm_provider.external_url, serialized_llm_provider.api_key, serialized_llm_provider.model, - serialized_llm_provider.toolkit_permissions, - serialized_llm_provider.storage_bucket_permissions, - serialized_llm_provider.allowed_message_senders, + None, ) } + + pub fn from_provider_or_agent( + provider_or_agent: ProviderOrAgent, + db: Arc, + ) -> Result { + match provider_or_agent { + ProviderOrAgent::LLMProvider(serialized_llm_provider) => { + Ok(Self::from_serialized_llm_provider(serialized_llm_provider)) + } + ProviderOrAgent::Agent(agent) => { + let llm_id = &agent.llm_provider_id; + let llm_provider = db + .get_llm_provider(llm_id, &agent.full_identity_name) + .map_err(|_e| LLMProviderError::AgentNotFound(llm_id.clone()))?; + if let Some(llm_provider) = llm_provider { + Ok(Self::from_serialized_llm_provider(llm_provider)) + } else { + Err(LLMProviderError::AgentNotFound(llm_id.clone())) + } + } + } + } } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/llm_provider_to_serialization.rs b/shinkai-bin/shinkai-node/src/llm_provider/llm_provider_to_serialization.rs index 8c26ae8ed..036564b22 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/llm_provider_to_serialization.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/llm_provider_to_serialization.rs @@ -7,13 +7,9 @@ impl From for SerializedLLMProvider { SerializedLLMProvider { id: agent.id, full_identity_name: agent.full_identity_name, - perform_locally: agent.perform_locally, external_url: agent.external_url, api_key: agent.api_key, model: agent.model, - toolkit_permissions: agent.toolkit_permissions, - storage_bucket_permissions: agent.storage_bucket_permissions, - allowed_message_senders: agent.allowed_message_senders, } } } diff --git a/shinkai-bin/shinkai-node/src/llm_provider/parsing_helper.rs b/shinkai-bin/shinkai-node/src/llm_provider/parsing_helper.rs index f11c0976f..19d5c94df 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/parsing_helper.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/parsing_helper.rs @@ -2,7 +2,8 @@ use super::error::LLMProviderError; use super::execution::prompts::general_prompts::JobPromptGenerator; use super::job_manager::JobManager; use super::llm_stopper::LLMStopper; -use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; +use shinkai_db::db::ShinkaiDB; +use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent; use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_vector_resources::embedding_generator::EmbeddingGenerator; use shinkai_vector_resources::file_parser::file_parser::ShinkaiFileParser; @@ -19,8 +20,9 @@ impl ParsingHelper { /// Given a list of TextGroup, generates a description using the Agent's LLM pub async fn generate_description( text_groups: &Vec, - agent: SerializedLLMProvider, + agent: ProviderOrAgent, max_node_text_size: u64, + db: Arc, ) -> Result { let descriptions = ShinkaiFileParser::process_groups_into_descriptions_list(text_groups, 10000, 300); let prompt = JobPromptGenerator::simple_doc_description(descriptions); @@ -35,6 +37,7 @@ impl ParsingHelper { None, None, llm_stopper.clone(), + db.clone(), ) .await { @@ -74,9 +77,10 @@ impl ParsingHelper { generator: &dyn EmbeddingGenerator, file_name: String, parsing_tags: &Vec, - agent: Option, + agent: Option, max_node_text_size: u64, distribution_info: DistributionInfo, + db: Arc, ) -> Result { let cleaned_name = ShinkaiFileParser::clean_name(&file_name); let source = VRSourceReference::from_file(&file_name, TextChunkingStrategy::V1)?; @@ -90,7 +94,7 @@ impl ParsingHelper { let mut desc = None; if let Some(actual_agent) = agent { - desc = Some(Self::generate_description(&text_groups, actual_agent, max_node_text_size).await?); + desc = Some(Self::generate_description(&text_groups, actual_agent, max_node_text_size, db.clone()).await?); } else { let description_text = ShinkaiFileParser::process_groups_into_description( &text_groups, @@ -120,7 +124,8 @@ impl ParsingHelper { pub async fn process_files_into_vrkai( files: Vec<(String, Vec, DistributionInfo)>, generator: &dyn EmbeddingGenerator, - agent: Option, + agent: Option, + db: Arc, ) -> Result, LLMProviderError> { #[allow(clippy::type_complexity)] let (vrkai_files, other_files): ( @@ -160,6 +165,7 @@ impl ParsingHelper { agent.clone(), (generator.model_type().max_input_token_count() - 20) as u64, file.2.clone(), + db.clone(), ) .await?; diff --git a/shinkai-bin/shinkai-node/src/managers/identity_manager.rs b/shinkai-bin/shinkai-node/src/managers/identity_manager.rs index 47e69d306..b54499955 100644 --- a/shinkai-bin/shinkai-node/src/managers/identity_manager.rs +++ b/shinkai-bin/shinkai-node/src/managers/identity_manager.rs @@ -253,15 +253,6 @@ impl IdentityManager { } } - pub async fn search_local_llm_provider( - &self, - agent_id: &str, - profile: &ShinkaiName, - ) -> Option { - let db_arc = self.db.upgrade()?; - db_arc.get_llm_provider(agent_id, profile).ok().flatten() - } - // Primarily for testing pub fn get_all_subidentities_devices_and_llm_providers(&self) -> Vec { self.local_identities.clone() diff --git a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs index 5e3e71cd4..e295708ea 100644 --- a/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs +++ b/shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs @@ -5,7 +5,7 @@ use crate::llm_provider::{ use shinkai_db::db::ShinkaiDB; use shinkai_message_primitives::schemas::{ llm_message::LlmMessage, - llm_providers::serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}, + llm_providers::{common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}}, prompts::Prompt, shinkai_name::ShinkaiName, }; @@ -436,6 +436,31 @@ impl ModelCapabilitiesManager { } } + /// Returns the maximum number of input tokens allowed for the given model, leaving room for output tokens. + pub fn get_max_input_tokens_for_provider_or_agent( + provider_or_agent: ProviderOrAgent, + db: Arc, + ) -> Option { + match provider_or_agent { + ProviderOrAgent::LLMProvider(serialized_llm_provider) => { + Some(ModelCapabilitiesManager::get_max_input_tokens(&serialized_llm_provider.model)) + } + ProviderOrAgent::Agent(agent) => { + let llm_id = &agent.llm_provider_id; + let profile = agent.full_identity_name.extract_profile().ok()?; + if let Some(llm_provider) = db.get_llm_provider(llm_id, &profile).ok() { + if let Some(model) = llm_provider { + Some(ModelCapabilitiesManager::get_max_input_tokens(&model.model)) + } else { + None + } + } else { + None + } + } + } + } + /// Returns the maximum number of input tokens allowed for the given model, leaving room for output tokens. pub fn get_max_input_tokens(model: &LLMProviderInterface) -> usize { let max_tokens = Self::get_max_tokens(model); @@ -604,6 +629,31 @@ impl ModelCapabilitiesManager { (num as f32 * 1.04) as usize } + /// Returns whether the given model supports tool/function calling capabilities + pub fn has_tool_capabilities_for_provider_or_agent( + provider_or_agent: ProviderOrAgent, + db: Arc, + stream: Option, + ) -> bool { + match provider_or_agent { + ProviderOrAgent::LLMProvider(serialized_llm_provider) => { + ModelCapabilitiesManager::has_tool_capabilities(&serialized_llm_provider.model, stream) + } + ProviderOrAgent::Agent(agent) => { + let llm_id = &agent.llm_provider_id; + if let Some(llm_provider) = db.get_llm_provider(llm_id, &agent.full_identity_name).ok() { + if let Some(model) = llm_provider { + ModelCapabilitiesManager::has_tool_capabilities(&model.model, stream) + } else { + false + } + } else { + false + } + } + } + } + /// Returns whether the given model supports tool/function calling capabilities pub fn has_tool_capabilities(model: &LLMProviderInterface, stream: Option) -> bool { eprintln!("has tool capabilities model: {:?}", model); diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 54aab118c..311e6b25f 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -2952,6 +2952,37 @@ impl Node { let _ = Node::v2_api_stop_llm(db_clone, stopper_clone, bearer, inbox_name, res).await; }); } + NodeCommand::V2ApiAddAgent { bearer, agent, res } => { + let db_clone = Arc::clone(&self.db); + let identity_manager_clone = self.identity_manager.clone(); + tokio::spawn(async move { + let _ = Node::v2_api_add_agent(db_clone, identity_manager_clone, bearer, agent, res).await; + }); + } + NodeCommand::V2ApiRemoveAgent { bearer, agent_id, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_remove_agent(db_clone, bearer, agent_id, res).await; + }); + } + NodeCommand::V2ApiUpdateAgent { bearer, partial_agent, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_update_agent(db_clone, bearer, partial_agent, res).await; + }); + } + NodeCommand::V2ApiGetAgent { bearer, agent_id, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_get_agent(db_clone, bearer, agent_id, res).await; + }); + } + NodeCommand::V2ApiGetAllAgents { bearer, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_get_all_agents(db_clone, bearer, res).await; + }); + } NodeCommand::V2ApiRetryMessage { bearer, inbox_name, diff --git a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_internal_commands.rs b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_internal_commands.rs index 4c8a8cefb..85b523379 100644 --- a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_internal_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_internal_commands.rs @@ -126,8 +126,6 @@ impl Node { return Vec::new(); } }; - // Start the timer - let start = Instant::now(); let result = match db.get_inboxes_for_profile(standard_identity) { Ok(inboxes) => inboxes, Err(e) => { @@ -139,9 +137,6 @@ impl Node { return Vec::new(); } }; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get all inboxes: {:?}", duration); result } @@ -674,15 +669,11 @@ impl Node { requester_profile.full_name, sanitized_model )) .expect("Failed to create ShinkaiName"), - perform_locally: false, external_url: Some(external_url.to_string()), api_key: Some("".to_string()), model: LLMProviderInterface::Ollama(Ollama { model_type: model.clone(), }), - toolkit_permissions: vec![], - storage_bucket_permissions: vec![], - allowed_message_senders: vec![], } }) .collect(); diff --git a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_vecfs_commands.rs b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_vecfs_commands.rs index 3dc1a37ad..575f79360 100644 --- a/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_vecfs_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v1_api/api_v1_vecfs_commands.rs @@ -1169,7 +1169,7 @@ impl Node { } // TODO: provide a default agent so that an LLM can be used to generate description of the VR for document files - let processed_vrkais = ParsingHelper::process_files_into_vrkai(dist_files, &*embedding_generator, None).await?; + let processed_vrkais = ParsingHelper::process_files_into_vrkai(dist_files, &*embedding_generator, None, db.clone()).await?; // Save the vrkais into VectorFS let mut success_messages = Vec::new(); diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs index 1745b0351..dd3359287 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs @@ -13,7 +13,7 @@ use shinkai_message_primitives::{ schemas::{ identity::{Identity, IdentityType, RegistrationCode}, inbox_name::InboxName, - llm_providers::serialized_llm_provider::SerializedLLMProvider, + llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::ShinkaiName, }, shinkai_message::{ @@ -951,4 +951,340 @@ impl Node { let _ = res.send(Ok(())).await; Ok(()) } + + pub async fn v2_api_add_agent( + db: Arc, + identity_manager: Arc>, + bearer: String, + agent: Agent, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the profile name from the identity manager + let requester_name = match identity_manager.lock().await.get_main_identity() { + Some(Identity::Standard(std_identity)) => std_identity.clone().full_identity_name, + _ => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Bad Request".to_string(), + message: "Wrong identity type. Expected Standard identity.".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Construct the expected full identity name + let expected_full_identity_name = ShinkaiName::new(format!( + "{}/main/agent/{}", + requester_name.get_node_name_string(), + agent.agent_id + )) + .unwrap(); + + // Check if the full identity name matches + if agent.full_identity_name != expected_full_identity_name { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Bad Request".to_string(), + message: "Invalid full identity name.".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + // TODO: validate tools + // TODO: validate knowledge + + // Check if the llm_provider_id exists + match db.get_llm_provider(&agent.llm_provider_id, &requester_name) { + Ok(Some(_)) => { + // Check if the agent_id already exists + match db.get_agent(&agent.agent_id) { + Ok(Some(_)) => { + let api_error = APIError { + code: StatusCode::CONFLICT.as_u16(), + error: "Conflict".to_string(), + message: "agent_id already exists".to_string(), + }; + let _ = res.send(Err(api_error)).await; + } + Ok(None) => { + // Add the agent to the database + match db.add_agent(agent, &requester_name) { + Ok(_) => { + let _ = res.send(Ok("Agent added successfully".to_string())).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to add agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to check agent_id: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: "llm_provider_id does not exist".to_string(), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to check llm_provider_id: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_remove_agent( + db: Arc, + bearer: String, + agent_id: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Remove the agent from the database + match db.remove_agent(&agent_id) { + Ok(_) => { + let _ = res.send(Ok("Agent removed successfully".to_string())).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to remove agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_update_agent( + db: Arc, + bearer: String, + partial_agent: serde_json::Value, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Extract agent_id from partial_agent + let agent_id = match partial_agent.get("agent_id").and_then(|id| id.as_str()) { + Some(id) => id.to_string(), + None => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Bad Request".to_string(), + message: "agent_id is missing in the request".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Retrieve the existing agent from the database + let existing_agent = match db.get_agent(&agent_id) { + Ok(Some(agent)) => agent, + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: "Agent not found".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Database error: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Construct the full identity name + let full_identity_name = match ShinkaiName::new(format!( + "{}/main/agent/{}", + existing_agent.full_identity_name.get_node_name_string(), + agent_id + )) { + Ok(name) => name, + Err(_) => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Bad Request".to_string(), + message: "Failed to construct full identity name.".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Manually merge fields from partial_agent with existing_agent + let updated_agent = Agent { + name: partial_agent + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or(&existing_agent.name) + .to_string(), + agent_id: existing_agent.agent_id.clone(), // Keep the original agent_id + llm_provider_id: partial_agent + .get("llm_provider_id") + .and_then(|v| v.as_str()) + .unwrap_or(&existing_agent.llm_provider_id) + .to_string(), + // TODO: decide if we keep this + // instructions: partial_agent + // .get("instructions") + // .and_then(|v| v.as_str()) + // .unwrap_or(&existing_agent.instructions) + // .to_string(), + ui_description: partial_agent + .get("ui_description") + .and_then(|v| v.as_str()) + .unwrap_or(&existing_agent.ui_description) + .to_string(), + knowledge: partial_agent + .get("knowledge") + .and_then(|v| v.as_array()) + .map_or(existing_agent.knowledge.clone(), |v| { + v.iter().filter_map(|s| s.as_str().map(String::from)).collect() + }), + storage_path: partial_agent + .get("storage_path") + .and_then(|v| v.as_str()) + .unwrap_or(&existing_agent.storage_path) + .to_string(), + tools: partial_agent + .get("tools") + .and_then(|v| v.as_array()) + .map_or(existing_agent.tools.clone(), |v| { + v.iter().filter_map(|s| s.as_str().map(String::from)).collect() + }), + debug_mode: partial_agent + .get("debug_mode") + .and_then(|v| v.as_bool()) + .unwrap_or(existing_agent.debug_mode), + config: partial_agent.get("config").map_or(existing_agent.config.clone(), |v| { + serde_json::from_value(v.clone()).unwrap_or(existing_agent.config.clone()) + }), + full_identity_name, // Set the constructed full identity name + }; + + // Update the agent in the database + match db.update_agent(updated_agent.clone()) { + Ok(_) => { + let _ = res.send(Ok(updated_agent)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to update agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_get_agent( + db: Arc, + bearer: String, + agent_id: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the agent from the database + match db.get_agent(&agent_id) { + Ok(Some(agent)) => { + let _ = res.send(Ok(agent)).await; + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: "Agent not found".to_string(), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_get_all_agents( + db: Arc, + bearer: String, + res: Sender, APIError>>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve all agents from the database + match db.get_all_agents() { + Ok(agents) => { + let _ = res.send(Ok(agents)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve agents: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs index 7a73b31b6..ecb805682 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs @@ -200,7 +200,7 @@ impl Node { // Retrieve the job to get the llm_provider let llm_provider = match db.get_job_with_options(&job_message.job_id, false, false) { - Ok(job) => job.parent_llm_provider_id.clone(), + Ok(job) => job.parent_agent_or_llm_provider_id.clone(), Err(err) => { let api_error = APIError { code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), @@ -444,8 +444,6 @@ impl Node { } }; - // Start the timer - let start = Instant::now(); // Retrieve all smart inboxes for the profile let smart_inboxes = match db.get_all_smart_inboxes_for_profile(main_identity) { Ok(inboxes) => inboxes, @@ -460,21 +458,12 @@ impl Node { } }; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get all inboxes: {:?}", duration); - - let start = Instant::now(); // Convert SmartInbox to V2SmartInbox let v2_smart_inboxes: Result, NodeError> = smart_inboxes .into_iter() .map(Self::convert_smart_inbox_to_v2_smart_inbox) .collect(); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to convert smart inboxes: {:?}", duration); - match v2_smart_inboxes { Ok(inboxes) => { let _ = res.send(Ok(inboxes)).await; @@ -1209,7 +1198,7 @@ impl Node { node_name.node_name, "main".to_string(), ShinkaiSubidentityType::Agent, - source_job.parent_llm_provider_id.clone(), + source_job.parent_agent_or_llm_provider_id.clone(), ) { Ok(name) => name, Err(err) => { @@ -1243,7 +1232,7 @@ impl Node { let forked_job_id = format!("jobid_{}", uuid::Uuid::new_v4()); match db.create_new_job( forked_job_id.clone(), - source_job.parent_llm_provider_id, + source_job.parent_agent_or_llm_provider_id, source_job.scope_with_files.clone().unwrap(), source_job.is_hidden, source_job.associated_ui, diff --git a/shinkai-bin/shinkai-node/src/utils/environment.rs b/shinkai-bin/shinkai-node/src/utils/environment.rs index 9bc9dd0b6..5d57edc51 100644 --- a/shinkai-bin/shinkai-node/src/utils/environment.rs +++ b/shinkai-bin/shinkai-node/src/utils/environment.rs @@ -76,13 +76,9 @@ pub fn fetch_llm_provider_env(global_identity: String) -> Vec Result<(), ShinkaiDBError> { + // Construct the database key for the agent + let agent_id_for_db = Self::db_llm_provider_id(&agent.agent_id, profile)?; + + // Validate the new ShinkaiName + let agent_name_str = format!( + "{}/{}/agent/{}", + profile.node_name, + profile.profile_name.clone().unwrap_or_default(), + agent.agent_id + ); + let _agent_name = ShinkaiName::new(agent_name_str.clone()).map_err(|_| { + ShinkaiDBError::InvalidIdentityName(format!("Invalid ShinkaiName: {}", agent_name_str)) + })?; + + // Check for collision with llm_provider_id + let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; + let llm_provider_key = format!("agent_placeholder_value_to_match_prefix_abcdef_{}", agent_id_for_db); + let llm_provider_exists = self.db.get_cf(cf_node_and_users, llm_provider_key.as_bytes())?.is_some(); + if llm_provider_exists { + return Err(ShinkaiDBError::IdCollision(format!( + "ID collision detected for agent_id: {}", + agent.agent_id + ))); + } + + // Serialize the agent to bytes + let bytes = to_vec(&agent).unwrap(); + let agent_key = format!("new_agentic_placeholder_values_to_match_prefix_{}", agent.agent_id); + + // Add the agent to the database under NodeAndUsers + self.db.put_cf(cf_node_and_users, agent_key.as_bytes(), bytes)?; + + Ok(()) + } + + pub fn remove_agent(&self, agent_id: &str) -> Result<(), ShinkaiDBError> { + let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; + let agent_key = format!("new_agentic_placeholder_values_to_match_prefix_{}", agent_id); + + // Check if the agent exists + let agent_exists = self.db.get_cf(cf_node_and_users, agent_key.as_bytes())?.is_some(); + if !agent_exists { + return Err(ShinkaiDBError::DataNotFound); + } + + // Remove the agent from the database + self.db.delete_cf(cf_node_and_users, agent_key.as_bytes())?; + + Ok(()) + } + + pub fn get_all_agents(&self) -> Result, ShinkaiDBError> { + let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; + let mut result = Vec::new(); + let prefix = b"new_agentic_placeholder_values_to_match_prefix_"; + + let iter = self.db.prefix_iterator_cf(cf_node_and_users, prefix); + for item in iter { + match item { + Ok((_, value)) => { + let agent: Agent = from_slice(value.as_ref()).unwrap(); + result.push(agent); + } + Err(e) => return Err(ShinkaiDBError::RocksDBError(e)), + } + } + + Ok(result) + } + + pub fn get_agent(&self, agent_id: &str) -> Result, ShinkaiDBError> { + let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; + let agent_key = format!("new_agentic_placeholder_values_to_match_prefix_{}", agent_id); + let agent_bytes = self.db.get_cf(cf_node_and_users, agent_key.as_bytes())?; + + if let Some(bytes) = agent_bytes { + let agent: Agent = from_slice(&bytes)?; + Ok(Some(agent)) + } else { + Ok(None) + } + } + + pub fn update_agent(&self, updated_agent: Agent) -> Result<(), ShinkaiDBError> { + let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; + let agent_key = format!("new_agentic_placeholder_values_to_match_prefix_{}", updated_agent.agent_id); + + // Check if the agent exists + let agent_exists = self.db.get_cf(cf_node_and_users, agent_key.as_bytes())?.is_some(); + if !agent_exists { + return Err(ShinkaiDBError::DataNotFound); + } + + // Serialize the updated agent to bytes + let bytes = to_vec(&updated_agent).unwrap(); + + // Update the agent in the database + self.db.put_cf(cf_node_and_users, agent_key.as_bytes(), bytes)?; + + Ok(()) + } +} \ No newline at end of file diff --git a/shinkai-libs/shinkai-db/src/db/db_errors.rs b/shinkai-libs/shinkai-db/src/db/db_errors.rs index f31cdcf45..08e6548b8 100644 --- a/shinkai-libs/shinkai-db/src/db/db_errors.rs +++ b/shinkai-libs/shinkai-db/src/db/db_errors.rs @@ -58,6 +58,7 @@ pub enum ShinkaiDBError { WorkflowNotFound(String), SheetNotFound(String), Other(String), + IdCollision(String), } impl fmt::Display for ShinkaiDBError { @@ -125,6 +126,7 @@ impl fmt::Display for ShinkaiDBError { ShinkaiDBError::WorkflowNotFound(e) => write!(f, "Workflow not found: {}", e), ShinkaiDBError::SheetNotFound(e) => write!(f, "Sheet not found: {}", e), ShinkaiDBError::Other(e) => write!(f, "Other error: {}", e), + ShinkaiDBError::IdCollision(e) => write!(f, "Id collision: {}", e), } } } @@ -188,6 +190,7 @@ impl PartialEq for ShinkaiDBError { } (ShinkaiDBError::DeviceNameNonExistent(msg1), ShinkaiDBError::DeviceNameNonExistent(msg2)) => msg1 == msg2, (ShinkaiDBError::Other(msg1), ShinkaiDBError::Other(msg2)) => msg1 == msg2, + (ShinkaiDBError::IdCollision(msg1), ShinkaiDBError::IdCollision(msg2)) => msg1 == msg2, _ => false, } } diff --git a/shinkai-libs/shinkai-db/src/db/db_inbox.rs b/shinkai-libs/shinkai-db/src/db/db_inbox.rs index ba186cc08..7f58bc527 100644 --- a/shinkai-libs/shinkai-db/src/db/db_inbox.rs +++ b/shinkai-libs/shinkai-db/src/db/db_inbox.rs @@ -489,15 +489,8 @@ impl ShinkaiDB { &self, profile_name_identity: StandardIdentity, ) -> Result, ShinkaiDBError> { - // Start the timer - let start = Instant::now(); - let inboxes = self.get_inboxes_for_profile(profile_name_identity.clone())?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get inboxes: {:?}", duration); - let mut smart_inboxes = Vec::new(); for inbox_id in inboxes { @@ -510,10 +503,6 @@ impl ShinkaiDB { .next() .and_then(|mut v| v.pop()); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get last message: {:?}", duration); - let cf_inbox = self.get_cf_handle(Topic::Inbox).unwrap(); let inbox_smart_inbox_name_key = format!("{}_smart_inbox_name", &inbox_id); let custom_name = match self.db.get_cf(cf_inbox, inbox_smart_inbox_name_key.as_bytes())? { @@ -543,9 +532,6 @@ impl ShinkaiDB { false }; - // Start the timer - let start = Instant::now(); - let agent_subset = { let profile_result = profile_name_identity.full_identity_name.clone().extract_profile(); match profile_result { @@ -554,13 +540,9 @@ impl ShinkaiDB { match InboxName::new(inbox_id.clone())? { InboxName::JobInbox { unique_id, .. } => { // Start the timer - let start = Instant::now(); let job = self.get_job_with_options(&unique_id, false, false)?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get job: {:?}", duration); + let agent_id = job.parent_agent_or_llm_provider_id; - let agent_id = job.parent_llm_provider_id; // TODO: add caching so we don't call this every time for the same agent_id match self.get_llm_provider(&agent_id, &p) { Ok(agent) => agent.map(LLMProviderSubset::from_serialized_llm_provider), @@ -577,10 +559,6 @@ impl ShinkaiDB { } }; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get agent subset: {:?}", duration); - let smart_inbox = SmartInbox { inbox_id: inbox_id.clone(), custom_name, @@ -595,9 +573,6 @@ impl ShinkaiDB { smart_inboxes.push(smart_inbox); } - // Start the timer - let start = Instant::now(); - // Sort the smart_inboxes by the timestamp of the last message smart_inboxes.sort_by(|a, b| match (&a.last_message, &b.last_message) { (Some(a_msg), Some(b_msg)) => { @@ -610,10 +585,6 @@ impl ShinkaiDB { (None, None) => std::cmp::Ordering::Equal, }); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to sort smart inboxes: {:?}", duration); - Ok(smart_inboxes) } diff --git a/shinkai-libs/shinkai-db/src/db/db_jobs.rs b/shinkai-libs/shinkai-db/src/db/db_jobs.rs index 2445e2125..ec551708d 100644 --- a/shinkai-libs/shinkai-db/src/db/db_jobs.rs +++ b/shinkai-libs/shinkai-db/src/db/db_jobs.rs @@ -225,7 +225,7 @@ impl ShinkaiDB { is_hidden, datetime_created, is_finished, - parent_llm_provider_id: parent_agent_id, + parent_agent_or_llm_provider_id: parent_agent_id, scope, scope_with_files, conversation_inbox_name: conversation_inbox, @@ -278,7 +278,7 @@ impl ShinkaiDB { is_hidden, datetime_created, is_finished, - parent_llm_provider_id: parent_agent_id, + parent_agent_or_llm_provider_id: parent_agent_id, scope, scope_with_files, conversation_inbox_name: conversation_inbox, @@ -329,9 +329,6 @@ impl ShinkaiDB { let cf_jobs = self.get_cf_handle(Topic::Inbox).unwrap(); // Begin fetching the data from the DB - // Start the timer - let start = Instant::now(); - let scope_value = self .db .get_cf(cf_jobs, format!("jobinbox_{}_scope", job_id).as_bytes())? @@ -344,8 +341,6 @@ impl ShinkaiDB { Ok::(MinimalJobScope::from(&job_scope)) })?; - eprintln!("Scope: {:?}", scope); - let mut scope_with_files: Option = None; if fetch_scope_with_files { let scope_with_files_value = if let Some(value) = self @@ -363,23 +358,12 @@ impl ShinkaiDB { scope_with_files = Some(JobScope::from_bytes(&scope_with_files_value)?); } - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get scope: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let is_finished_value = self .db .get_cf(cf_jobs, format!("jobinbox_{}_is_finished", job_id).as_bytes())? .ok_or(ShinkaiDBError::DataNotFound)?; let is_finished = std::str::from_utf8(&is_finished_value)? == "true"; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get is_finished: {:?}", duration); - // Start the timer let start = Instant::now(); @@ -389,26 +373,12 @@ impl ShinkaiDB { .ok_or(ShinkaiDBError::DataNotFound)?; let datetime_created = std::str::from_utf8(&datetime_created_value)?.to_string(); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get datetime_created: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let parent_agent_id_value = self .db .get_cf(cf_jobs, format!("jobinbox_{}_agentid", job_id).as_bytes())? .ok_or(ShinkaiDBError::DataNotFound)?; let parent_agent_id = std::str::from_utf8(&parent_agent_id_value)?.to_string(); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get parent_agent_id: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let job_inbox_name = self .db .get_cf(cf_jobs, format!("jobinbox_{}_inboxname", job_id).as_bytes())? @@ -416,36 +386,15 @@ impl ShinkaiDB { let inbox_name = std::str::from_utf8(&job_inbox_name)?.to_string(); let conversation_inbox = InboxName::new(inbox_name)?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get conversation_inbox: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let is_hidden_value = self .db .get_cf(cf_jobs, format!("jobinbox_{}_is_hidden", job_id).as_bytes())? .unwrap_or_else(|| b"false".to_vec()); let is_hidden = std::str::from_utf8(&is_hidden_value)? == "true"; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get is_hidden: {:?}", duration); - - // Start the timer - let start = Instant::now(); - // Reads all of the step history by iterating let step_history = self.get_step_history(job_id, fetch_step_history)?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get step_history: {:?}", duration); - - // Start the timer - let start = Instant::now(); - // Try to read associated_ui let associated_ui_value = self .db @@ -454,13 +403,6 @@ impl ShinkaiDB { .flatten() .and_then(|value| serde_json::from_slice(&value).ok()); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get associated_ui: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let config_value = self .db .get_cf(cf_jobs, format!("jobinbox_{}_config", job_id).as_bytes()) @@ -468,19 +410,8 @@ impl ShinkaiDB { .flatten() .and_then(|value| serde_json::from_slice(&value).ok()); - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get config: {:?}", duration); - - // Start the timer - let start = Instant::now(); - let forked_jobs = self.get_forked_jobs(job_id)?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get forked_jobs: {:?}", duration); - Ok(( scope, scope_with_files, @@ -928,9 +859,6 @@ impl ShinkaiDB { /// Fetches all forked jobs for a specific Job from the DB fn get_forked_jobs(&self, job_id: &str) -> Result, ShinkaiDBError> { - // Start the timer - let start = Instant::now(); - let cf_inbox = self.get_cf_handle(Topic::Inbox).unwrap(); // TODO: this is wrong let forked_jobs_key = format!("jobinbox_{}_forked_jobs", job_id); @@ -938,9 +866,6 @@ impl ShinkaiDB { match self.db.get_cf(cf_inbox, forked_jobs_key.as_bytes()) { Ok(Some(value)) => { let forked_jobs: Vec = serde_json::from_slice(&value)?; - // Measure the elapsed time - let duration = start.elapsed(); - println!("Time taken to get forked jobs: {:?}", duration); Ok(forked_jobs) } Ok(None) => Ok(Vec::new()), diff --git a/shinkai-libs/shinkai-db/src/db/db_llm_providers.rs b/shinkai-libs/shinkai-db/src/db/db_llm_providers.rs index 4fa4fe5ae..8629035d3 100644 --- a/shinkai-libs/shinkai-db/src/db/db_llm_providers.rs +++ b/shinkai-libs/shinkai-db/src/db/db_llm_providers.rs @@ -63,17 +63,6 @@ impl ShinkaiDB { ); batch.put_cf(cf_node_and_users, profile_key.as_bytes(), []); - // Additionally, for each allowed message sender and toolkit permission, - // you can store them with a specific prefix to indicate their relationship to the llm provider. - for profile in &llm_provider.allowed_message_senders { - let profile_key = format!("agent_{}_profile_{}", &llm_provider_id_for_db_hash, profile); - batch.put_cf(cf_node_and_users, profile_key.as_bytes(), []); - } - for toolkit in &llm_provider.toolkit_permissions { - let toolkit_key = format!("agent_{}_toolkit_{}", &llm_provider_id_for_db_hash, toolkit); - batch.put_cf(cf_node_and_users, toolkit_key.as_bytes(), []); - } - // Write the batch self.db.write(batch)?; @@ -238,38 +227,6 @@ impl ShinkaiDB { Ok(profiles_with_access) } - pub fn get_llm_provider_toolkits_accessible( - &self, - llm_provider_id: &str, - profile: &ShinkaiName, - ) -> Result, ShinkaiDBError> { - let cf_node_and_users = self.cf_handle(Topic::NodeAndUsers.as_str())?; - let llm_provider_id_for_db = Self::db_llm_provider_id(llm_provider_id, profile)?; - let llm_provider_id_for_db_hash = Self::llm_provider_id_to_hash(&llm_provider_id_for_db); - let prefix = format!("agent_{}_toolkit_", llm_provider_id_for_db_hash); - let mut toolkits_accessible = Vec::new(); - - let iter = self.db.prefix_iterator_cf(cf_node_and_users, prefix.as_bytes()); - for item in iter { - match item { - Ok((key, _)) => { - // Extract toolkit name from the key - let key_str = String::from_utf8(key.to_vec()) - .map_err(|_| ShinkaiDBError::DataConversionError("UTF-8 conversion error".to_string()))?; - // Ensure the key follows the prefix convention before extracting the toolkit name - if key_str.starts_with(&prefix) { - if let Some(toolkit_name) = key_str.split('_').last() { - toolkits_accessible.push(toolkit_name.to_string()); - } - } - } - Err(e) => return Err(ShinkaiDBError::RocksDBError(e)), - } - } - - Ok(toolkits_accessible) - } - pub fn remove_profile_from_llm_provider_access( &self, llm_provider_id: &str, diff --git a/shinkai-libs/shinkai-db/src/db/mod.rs b/shinkai-libs/shinkai-db/src/db/mod.rs index 9abddd6c6..83472b2b8 100644 --- a/shinkai-libs/shinkai-db/src/db/mod.rs +++ b/shinkai-libs/shinkai-db/src/db/mod.rs @@ -24,4 +24,5 @@ pub mod db_uploaded_files_links; pub mod db_sheet; pub mod db_invoice; pub mod db_internal_invoice_request; -pub mod db_wallets; \ No newline at end of file +pub mod db_wallets; +pub mod db_agent; \ No newline at end of file diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs index 00a8dd9f3..55a573882 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs @@ -2,6 +2,7 @@ use async_channel::Sender; use reqwest::StatusCode; use serde::Deserialize; use serde_json::json; +use shinkai_message_primitives::schemas::llm_providers::agent::Agent; use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{ Exo, Gemini, Groq, LLMProviderInterface, LocalLLM, Ollama, OpenAI, ShinkaiBackend, }; @@ -143,6 +144,40 @@ pub fn general_routes( .and(warp::body::json()) .and_then(stop_llm_handler); + let add_agent_route = warp::path("add_agent") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(add_agent_handler); + + let remove_agent_route = warp::path("remove_agent") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(remove_agent_handler); + + let update_agent_route = warp::path("update_agent") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(update_agent_handler); + + let get_agent_route = warp::path("get_agent") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::path::param::()) + .and_then(get_agent_handler); + + let get_all_agents_route = warp::path("get_all_agents") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and_then(get_all_agents_handler); + public_keys_route .or(health_check_route) .or(initial_registration_route) @@ -160,6 +195,11 @@ pub fn general_routes( .or(download_file_from_inbox_route) .or(list_files_in_inbox_route) .or(stop_llm_route) + .or(add_agent_route) + .or(remove_agent_route) + .or(update_agent_route) + .or(get_agent_route) + .or(get_all_agents_route) } #[derive(Deserialize)] @@ -723,6 +763,169 @@ pub async fn stop_llm_handler( } } +#[utoipa::path( + post, + path = "/v2/add_agent", + request_body = Agent, + responses( + (status = 200, description = "Successfully added agent", body = String), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn add_agent_handler( + sender: Sender, + authorization: String, + agent: Agent, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiAddAgent { + bearer, + agent, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => Ok(warp::reply::json(&response)), + Err(error) => Err(warp::reject::custom(error)), + } +} + +#[utoipa::path( + post, + path = "/v2/remove_agent", + request_body = HashMap, + responses( + (status = 200, description = "Successfully removed agent", body = String), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn remove_agent_handler( + sender: Sender, + authorization: String, + payload: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let agent_id = payload.get("agent_id").cloned().unwrap_or_default(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiRemoveAgent { + bearer, + agent_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => Ok(warp::reply::json(&response)), + Err(error) => Err(warp::reject::custom(error)), + } +} + +#[utoipa::path( + post, + path = "/v2/update_agent", + request_body = serde_json::Value, + responses( + (status = 200, description = "Successfully updated agent", body = Agent), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn update_agent_handler( + sender: Sender, + authorization: String, + update_data: serde_json::Value, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiUpdateAgent { + bearer, + partial_agent: update_data, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(updated_agent) => Ok(warp::reply::json(&updated_agent)), + Err(error) => Err(warp::reject::custom(error)), + } +} + +#[utoipa::path( + get, + path = "/v2/get_agent/{agent_id}", + responses( + (status = 200, description = "Successfully retrieved agent", body = Agent), + (status = 404, description = "Agent not found", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn get_agent_handler( + sender: Sender, + authorization: String, + agent_id: String, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiGetAgent { + bearer, + agent_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(agent) => Ok(warp::reply::json(&agent)), + Err(error) => Err(warp::reject::custom(error)), + } +} + +#[utoipa::path( + get, + path = "/v2/get_all_agents", + responses( + (status = 200, description = "Successfully retrieved all agents", body = Vec), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn get_all_agents_handler( + sender: Sender, + authorization: String, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiGetAllAgents { + bearer, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(agents) => Ok(warp::reply::json(&agents)), + Err(error) => Err(warp::reject::custom(error)), + } +} + #[derive(OpenApi)] #[openapi( paths( @@ -743,6 +946,11 @@ pub async fn stop_llm_handler( scan_ollama_models_handler, add_ollama_models_handler, stop_llm_handler, + add_agent_handler, + remove_agent_handler, + update_agent_handler, + get_agent_handler, + get_all_agents_handler, ), components( schemas(APIAddOllamaModels, SerializedLLMProvider, ShinkaiName, LLMProviderInterface, @@ -750,7 +958,7 @@ pub async fn stop_llm_handler( OpenAI, Ollama, LocalLLM, Groq, Gemini, Exo, EncryptedShinkaiBody, ShinkaiBody, ShinkaiSubidentityType, ShinkaiBackend, InternalMetadata, MessageData, StopLLMRequest, NodeApiData, EncryptedShinkaiData, ShinkaiData, MessageSchemaType, - APIUseRegistrationCodeSuccessResponse, GetPublicKeysResponse, APIError) + APIUseRegistrationCodeSuccessResponse, GetPublicKeysResponse, APIError, Agent) ), tags( (name = "general", description = "General API endpoints") diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index a9ef4034c..3995e83b7 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -6,17 +6,7 @@ use ed25519_dalek::VerifyingKey; use serde_json::Value; use shinkai_message_primitives::{ schemas::{ - coinbase_mpc_config::CoinbaseMPCWalletConfig, - custom_prompt::CustomPrompt, - identity::{Identity, StandardIdentity}, - job_config::JobConfig, - llm_providers::serialized_llm_provider::SerializedLLMProvider, - shinkai_name::ShinkaiName, - shinkai_subscription::ShinkaiSubscription, - shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry}, - smart_inbox::{SmartInbox, V2SmartInbox}, - wallet_complementary::{WalletRole, WalletSource}, - wallet_mixed::NetworkIdentifier, + llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider}, coinbase_mpc_config::CoinbaseMPCWalletConfig, custom_prompt::CustomPrompt, identity::{Identity, StandardIdentity}, job_config::JobConfig, shinkai_name::ShinkaiName, shinkai_subscription::ShinkaiSubscription, shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry}, smart_inbox::{SmartInbox, V2SmartInbox}, wallet_complementary::{WalletRole, WalletSource}, wallet_mixed::NetworkIdentifier }, shinkai_message::{ shinkai_message::ShinkaiMessage, @@ -939,6 +929,30 @@ pub enum NodeCommand { inbox_name: String, res: Sender>, }, + V2ApiAddAgent { + bearer: String, + agent: Agent, + res: Sender>, + }, + V2ApiRemoveAgent { + bearer: String, + agent_id: String, + res: Sender>, + }, + V2ApiUpdateAgent { + bearer: String, + partial_agent: serde_json::Value, + res: Sender>, + }, + V2ApiGetAgent { + bearer: String, + agent_id: String, + res: Sender>, + }, + V2ApiGetAllAgents { + bearer: String, + res: Sender, APIError>>, + }, V2ApiRetryMessage { bearer: String, inbox_name: String, diff --git a/shinkai-libs/shinkai-message-primitives/src/lib.rs b/shinkai-libs/shinkai-message-primitives/src/lib.rs index 439021f33..8e4f568b8 100644 --- a/shinkai-libs/shinkai-message-primitives/src/lib.rs +++ b/shinkai-libs/shinkai-message-primitives/src/lib.rs @@ -1,6 +1,3 @@ - - - pub mod shinkai_message; pub mod schemas; pub mod shinkai_utils; diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs index 8ad6a0fb4..fa56e0c0c 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs @@ -30,7 +30,7 @@ pub struct Job { /// Marks if the job is finished or not pub is_finished: bool, /// Identity of the parent agent. We just use a full identity name for simplicity - pub parent_llm_provider_id: String, + pub parent_agent_or_llm_provider_id: String, /// (Simplified version) What VectorResources the Job has access to when performing vector searches pub scope: MinimalJobScope, /// (Full version) What VectorResources the Job has access to when performing vector searches, including files @@ -77,7 +77,7 @@ impl JobLike for Job { } fn parent_llm_provider_id(&self) -> &str { - &self.parent_llm_provider_id + &self.parent_agent_or_llm_provider_id } fn scope(&self) -> &MinimalJobScope { diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs index 891b892a0..9b572c52d 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/job_config.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use utoipa::ToSchema; -#[derive(Clone, Debug, Serialize, Deserialize, ToSchema)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, ToSchema)] pub struct JobConfig { pub custom_prompt: Option, // pub custom_system_prompt: String @@ -13,4 +13,39 @@ pub struct JobConfig { pub top_p: Option, pub stream: Option, pub other_model_params: Option, -} \ No newline at end of file + // TODO: add ctx_... +} + +impl JobConfig { + /// Merges two JobConfig instances, preferring values from `self` over `other`. + pub fn merge(&self, other: &JobConfig) -> JobConfig { + JobConfig { + // Prefer `self` (provided config) over `other` (agent's config) + custom_prompt: self.custom_prompt.clone().or_else(|| other.custom_prompt.clone()), + temperature: self.temperature.or(other.temperature), + max_tokens: self.max_tokens.or(other.max_tokens), + seed: self.seed.or(other.seed), + top_k: self.top_k.or(other.top_k), + top_p: self.top_p.or(other.top_p), + stream: self.stream.or(other.stream), + other_model_params: self + .other_model_params + .clone() + .or_else(|| other.other_model_params.clone()), + } + } + + /// Creates an empty JobConfig with all fields set to None. + pub fn empty() -> JobConfig { + JobConfig { + custom_prompt: None, + temperature: None, + max_tokens: None, + seed: None, + top_k: None, + top_p: None, + stream: None, + other_model_params: None, + } + } +} diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs new file mode 100644 index 000000000..94ae5fa11 --- /dev/null +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/agent.rs @@ -0,0 +1,19 @@ +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use crate::schemas::{job_config::JobConfig, shinkai_name::ShinkaiName}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] +pub struct Agent { + pub name: String, + pub agent_id: String, + pub full_identity_name: ShinkaiName, + pub llm_provider_id: String, // Connected + // pub instructions: String, // TODO: maybe we can remove on post to custom_prompt -- not super clean but not repetitive + pub ui_description: String, + pub knowledge: Vec, // TODO + pub storage_path: String, // TODO + pub tools: Vec, // Connected + pub debug_mode: bool, // TODO + pub config: Option, // Connected +} diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/common_agent_llm_provider.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/common_agent_llm_provider.rs new file mode 100644 index 000000000..77d348954 --- /dev/null +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/common_agent_llm_provider.rs @@ -0,0 +1,32 @@ +use crate::schemas::shinkai_name::ShinkaiName; + +use super::{agent::Agent, serialized_llm_provider::SerializedLLMProvider}; + +#[derive(Debug, Clone, PartialEq)] +pub enum ProviderOrAgent { + LLMProvider(SerializedLLMProvider), + Agent(Agent), +} + +impl ProviderOrAgent { + pub fn get_id(&self) -> &str { + match self { + ProviderOrAgent::LLMProvider(provider) => &provider.id, + ProviderOrAgent::Agent(agent) => &agent.agent_id, + } + } + + pub fn get_llm_provider_id(&self) -> &str { + match self { + ProviderOrAgent::LLMProvider(provider) => &provider.id, + ProviderOrAgent::Agent(agent) => &agent.llm_provider_id, + } + } + + pub fn get_full_identity_name(&self) -> &ShinkaiName { + match self { + ProviderOrAgent::LLMProvider(provider) => &provider.full_identity_name, + ProviderOrAgent::Agent(agent) => &agent.full_identity_name, + } + } +} diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/customized_agent.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/customized_agent.rs deleted file mode 100644 index 312c46455..000000000 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/customized_agent.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider; -use serde::{Deserialize, Serialize}; - -// Based on the great job by crewai (mostly for for compatibility) https://docs.crewai.com - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct CustomizedAgent { - pub role: String, - pub goal: String, - pub backstory: String, - pub llm: Option, - pub tools: Vec, - pub function_calling_llm: Option, - pub max_iter: Option, - pub max_rpm: Option, - pub max_execution_time: Option, - pub verbose: bool, - pub allow_delegation: bool, - pub step_callback: Option, -} - -impl CustomizedAgent { - #[allow(clippy::too_many_arguments)] - pub fn new( - role: String, - goal: String, - backstory: String, - llm: Option, - tools: Vec, - function_calling_llm: Option, - max_iter: Option, - max_rpm: Option, - max_execution_time: Option, - verbose: bool, - allow_delegation: bool, - step_callback: Option, - ) -> Self { - CustomizedAgent { - role, - goal, - backstory, - llm, - tools, - function_calling_llm, - max_iter, - max_rpm, - max_execution_time, - verbose, - allow_delegation, - step_callback, - } - } -} \ No newline at end of file diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/mod.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/mod.rs index 1ffaee575..dbdade8a7 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/mod.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/mod.rs @@ -1,2 +1,3 @@ pub mod serialized_llm_provider; -pub mod customized_agent; \ No newline at end of file +pub mod agent; +pub mod common_agent_llm_provider; \ No newline at end of file diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/serialized_llm_provider.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/serialized_llm_provider.rs index cf2efff76..f7f07e452 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/serialized_llm_provider.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/llm_providers/serialized_llm_provider.rs @@ -10,13 +10,9 @@ use utoipa::ToSchema; pub struct SerializedLLMProvider { pub id: String, pub full_identity_name: ShinkaiName, - pub perform_locally: bool, // TODO: Remove this and update libs pub external_url: Option, pub api_key: Option, pub model: LLMProviderInterface, - pub toolkit_permissions: Vec, - pub storage_bucket_permissions: Vec, - pub allowed_message_senders: Vec, } impl SerializedLLMProvider { @@ -53,15 +49,11 @@ impl SerializedLLMProvider { SerializedLLMProvider { id: "mock_agent".to_string(), full_identity_name: ShinkaiName::new("@@test.shinkai/main/agent/mock_agent".to_string()).unwrap(), - perform_locally: false, external_url: Some("https://api.example.com".to_string()), api_key: Some("mockapikey".to_string()), model: LLMProviderInterface::OpenAI(OpenAI { model_type: "gpt-4o-mini".to_string(), }), - toolkit_permissions: vec![], - storage_bucket_permissions: vec![], - allowed_message_senders: vec![], } } } diff --git a/shinkai-libs/shinkai-message-pyo3/src/shinkai_pyo3_utils/pyo3_serialized_llm_provider.rs b/shinkai-libs/shinkai-message-pyo3/src/shinkai_pyo3_utils/pyo3_serialized_llm_provider.rs index 091caf897..e8fc16ee6 100644 --- a/shinkai-libs/shinkai-message-pyo3/src/shinkai_pyo3_utils/pyo3_serialized_llm_provider.rs +++ b/shinkai-libs/shinkai-message-pyo3/src/shinkai_pyo3_utils/pyo3_serialized_llm_provider.rs @@ -31,10 +31,6 @@ impl PySerializedLLMProvider { .and_then(|k| k.get_item("id").ok().flatten()) .and_then(|v| v.extract::().ok()) .unwrap_or_else(|| String::new()); - let perform_locally = kwargs - .and_then(|k| k.get_item("perform_locally").ok().flatten()) - .and_then(|v| v.extract::().ok()) - .unwrap_or_else(|| false); let external_url = kwargs .and_then(|k| k.get_item("external_url").ok().flatten()) .and_then(|v| v.extract::().ok()); @@ -46,30 +42,14 @@ impl PySerializedLLMProvider { .and_then(|v| v.extract::().ok()) .map(|py_model| py_model.inner) .ok_or_else(|| PyErr::new::("model is required"))?; - let toolkit_permissions = kwargs - .and_then(|k| k.get_item("toolkit_permissions").ok().flatten()) - .and_then(|v| v.extract::>().ok()) - .unwrap_or_else(|| Vec::new()); - let storage_bucket_permissions = kwargs - .and_then(|k| k.get_item("storage_bucket_permissions").ok().flatten()) - .and_then(|v| v.extract::>().ok()) - .unwrap_or_else(|| Vec::new()); - let allowed_message_senders = kwargs - .and_then(|k| k.get_item("allowed_message_senders").ok().flatten()) - .and_then(|v| v.extract::>().ok()) - .unwrap_or_else(|| Vec::new()); Ok(Self { inner: SerializedLLMProvider { id, full_identity_name, - perform_locally, external_url, api_key, model, - toolkit_permissions, - storage_bucket_permissions, - allowed_message_senders, }, }) } @@ -91,13 +71,9 @@ impl PySerializedLLMProvider { inner: SerializedLLMProvider { id, full_identity_name, - perform_locally: false, external_url: Some(external_url), api_key, model, - toolkit_permissions: Vec::new(), - storage_bucket_permissions: Vec::new(), - allowed_message_senders: Vec::new(), }, }) } diff --git a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_llm_provider_wrapper.rs b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_llm_provider_wrapper.rs index f704361db..8b84338bf 100644 --- a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_llm_provider_wrapper.rs +++ b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_llm_provider_wrapper.rs @@ -19,13 +19,9 @@ pub trait SerializedLLMProviderJsValueConversion { fn from_strings( id: String, full_identity_name: String, - perform_locally: String, external_url: String, api_key: String, model: String, - toolkit_permissions: String, - storage_bucket_permissions: String, - allowed_message_senders: String, ) -> Result where Self: Sized; @@ -52,18 +48,11 @@ impl SerializedLLMProviderJsValueConversion for SerializedLLMProvider { fn from_strings( id: String, full_identity_name: String, - perform_locally: String, external_url: String, api_key: String, model: String, - toolkit_permissions: String, - storage_bucket_permissions: String, - allowed_message_senders: String, ) -> Result { // Convert the strings to the appropriate types - let perform_locally = perform_locally - .parse::() - .map_err(|_| JsValue::from_str("Invalid perform_locally"))?; let external_url = if external_url.is_empty() { None } else { @@ -73,32 +62,13 @@ impl SerializedLLMProviderJsValueConversion for SerializedLLMProvider { let model = model .parse::() .map_err(|_| JsValue::from_str("Invalid model"))?; - let toolkit_permissions = if toolkit_permissions.is_empty() { - Vec::new() - } else { - toolkit_permissions.split(',').map(|s| s.to_string()).collect() - }; - let storage_bucket_permissions = if storage_bucket_permissions.is_empty() { - Vec::new() - } else { - storage_bucket_permissions.split(',').map(|s| s.to_string()).collect() - }; - let allowed_message_senders = if allowed_message_senders.is_empty() { - Vec::new() - } else { - allowed_message_senders.split(',').map(|s| s.to_string()).collect() - }; Ok(SerializedLLMProvider { id, full_identity_name: ShinkaiName::new(full_identity_name)?, - perform_locally, external_url, api_key, model, - toolkit_permissions, - storage_bucket_permissions, - allowed_message_senders, }) } } @@ -124,24 +94,16 @@ impl SerializedLLMProviderWrapper { pub fn from_strings( id: String, full_identity_name: String, - perform_locally: String, external_url: String, api_key: String, model: String, - toolkit_permissions: String, - storage_bucket_permissions: String, - allowed_message_senders: String, ) -> Result { let inner = SerializedLLMProvider::from_strings( id, full_identity_name, - perform_locally, external_url, api_key, model, - toolkit_permissions, - storage_bucket_permissions, - allowed_message_senders, )?; Ok(SerializedLLMProviderWrapper { inner }) } diff --git a/shinkai-libs/shinkai-message-wasm/tests/serialized_llm_provider_conversion_tests.rs b/shinkai-libs/shinkai-message-wasm/tests/serialized_llm_provider_conversion_tests.rs index 0ef58cb06..3ad5c2f9f 100644 --- a/shinkai-libs/shinkai-message-wasm/tests/serialized_llm_provider_conversion_tests.rs +++ b/shinkai-libs/shinkai-message-wasm/tests/serialized_llm_provider_conversion_tests.rs @@ -17,13 +17,9 @@ mod tests { let serialized_llm_provider_wrapper = SerializedLLMProviderWrapper::from_strings( "test_agent".to_string(), "@@node.shinkai/main/agent/test_agent".to_string(), - "false".to_string(), "http://example.com".to_string(), "123456".to_string(), "openai:gpt-3.5-turbo-1106".to_string(), - "permission1,permission2".to_string(), - "bucket1,bucket2".to_string(), - "sender1,sender2".to_string(), ) .unwrap(); @@ -37,7 +33,6 @@ mod tests { agent.full_identity_name.to_string(), "@@node.shinkai/main/agent/test_agent" ); - assert_eq!(agent.perform_locally, false); assert_eq!(agent.external_url, Some("http://example.com".to_string())); assert_eq!(agent.api_key, Some("123456".to_string())); assert_eq!( @@ -46,18 +41,6 @@ mod tests { model_type: "gpt-3.5-turbo-1106".to_string() }) ); - assert_eq!( - agent.toolkit_permissions, - vec!["permission1".to_string(), "permission2".to_string()] - ); - assert_eq!( - agent.storage_bucket_permissions, - vec!["bucket1".to_string(), "bucket2".to_string()] - ); - assert_eq!( - agent.allowed_message_senders, - vec!["sender1".to_string(), "sender2".to_string()] - ); } #[cfg(target_arch = "wasm32")] @@ -68,13 +51,9 @@ mod tests { let serialized_llm_provider_wrapper = SerializedLLMProviderWrapper::from_strings( "test_agent".to_string(), "@@node.shinkai/main/agent/test_agent".to_string(), - "false".to_string(), "http://example.com".to_string(), "123456".to_string(), "openai:gpt-3.5-turbo-1106".to_string(), - "permission1,permission2".to_string(), - "bucket1,bucket2".to_string(), - "sender1,sender2".to_string(), ) .unwrap(); @@ -100,7 +79,6 @@ mod tests { agent.full_identity_name.to_string(), "@@node.shinkai/main/agent/test_agent" ); - assert_eq!(agent.perform_locally, false); assert_eq!(agent.external_url, Some("http://example.com".to_string())); assert_eq!(agent.api_key, Some("123456".to_string())); assert_eq!( @@ -109,17 +87,5 @@ mod tests { model_type: "gpt-3.5-turbo-1106".to_string() }) ); - assert_eq!( - agent.toolkit_permissions, - vec!["permission1".to_string(), "permission2".to_string()] - ); - assert_eq!( - agent.storage_bucket_permissions, - vec!["bucket1".to_string(), "bucket2".to_string()] - ); - assert_eq!( - agent.allowed_message_senders, - vec!["sender1".to_string(), "sender2".to_string()] - ); } }