Add new (basic) agents #631

Merged · 12 commits · Nov 1, 2024
docs/openapi/general.yaml — 18 changes: 0 additions & 18 deletions
@@ -763,16 +763,8 @@ components:
     required:
       - id
       - full_identity_name
-      - perform_locally
       - model
-      - toolkit_permissions
-      - storage_bucket_permissions
-      - allowed_message_senders
     properties:
-      allowed_message_senders:
-        type: array
-        items:
-          type: string
       api_key:
         type: string
         nullable: true
@@ -785,16 +777,6 @@
       type: string
       model:
         $ref: '#/components/schemas/LLMProviderInterface'
-      perform_locally:
-        type: boolean
-      storage_bucket_permissions:
-        type: array
-        items:
-          type: string
-      toolkit_permissions:
-        type: array
-        items:
-          type: string
   ShinkaiBackend:
     type: object
     required:
docs/openapi/jobs.yaml — 18 changes: 0 additions & 18 deletions
@@ -1412,16 +1412,8 @@ components:
     required:
       - id
       - full_identity_name
-      - perform_locally
       - model
-      - toolkit_permissions
-      - storage_bucket_permissions
-      - allowed_message_senders
     properties:
-      allowed_message_senders:
-        type: array
-        items:
-          type: string
       api_key:
         type: string
         nullable: true
@@ -1434,16 +1426,6 @@
       type: string
       model:
         $ref: '#/components/schemas/LLMProviderInterface'
-      perform_locally:
-        type: boolean
-      storage_bucket_permissions:
-        type: array
-        items:
-          type: string
-      toolkit_permissions:
-        type: array
-        items:
-          type: string
   SheetJobAction:
     type: object
     required:
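Both OpenAPI files drop the same five schema members — perform_locally, toolkit_permissions, storage_bucket_permissions, allowed_message_senders (and the allowed_message_senders property block) — leaving a slimmer provider schema. For reference, a sketch of the resulting Rust shape, assuming the struct mirrors the schema and the test fixtures later in this diff; the real definition lives in shinkai_message_primitives and may carry additional fields and derives:

// Sketch only: the slimmed SerializedLLMProvider implied by the schema edits
// above and the test fixtures below in this PR.
pub struct SerializedLLMProvider {
    pub id: String,
    pub full_identity_name: ShinkaiName,
    pub external_url: Option<String>, // e.g. "https://api.openai.com"
    pub api_key: Option<String>,      // nullable in the OpenAPI schema
    pub model: LLMProviderInterface,
}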
shinkai-bin/shinkai-node/src/llm_provider/error.rs — 3 changes: 3 additions & 0 deletions
@@ -87,6 +87,7 @@ pub enum LLMProviderError {
     ToolNotFound(String),
     ToolRetrievalError(String),
     ToolSearchError(String),
+    AgentNotFound(String),
 }
 
 impl fmt::Display for LLMProviderError {
@@ -176,6 +177,7 @@
             LLMProviderError::ToolNotFound(s) => write!(f, "Tool not found: {}", s),
             LLMProviderError::ToolRetrievalError(s) => write!(f, "Tool retrieval error: {}", s),
             LLMProviderError::ToolSearchError(s) => write!(f, "Tool search error: {}", s),
+            LLMProviderError::AgentNotFound(s) => write!(f, "Agent not found: {}", s),
         }
     }
 }
@@ -255,6 +257,7 @@
             LLMProviderError::ToolNotFound(_) => "ToolNotFound",
             LLMProviderError::ToolRetrievalError(_) => "ToolRetrievalError",
             LLMProviderError::ToolSearchError(_) => "ToolSearchError",
+            LLMProviderError::AgentNotFound(_) => "AgentNotFound",
         };
 
         let error_message = format!("{}", self);
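AgentNotFound rounds out the error enum alongside ToolNotFound and friends, so it serializes with the error name "AgentNotFound". A minimal usage sketch — the resolve_agent helper and the db.get_agent accessor are hypothetical, shown only to illustrate where the variant would surface:

// Hypothetical caller (not part of this PR): map a failed agent lookup to the
// new variant. `db.get_agent` is an assumed accessor returning Option<Agent>.
fn resolve_agent(db: &ShinkaiDB, agent_id: &str) -> Result<Agent, LLMProviderError> {
    db.get_agent(agent_id)
        .ok_or_else(|| LLMProviderError::AgentNotFound(agent_id.to_string()))
}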
(file path not shown in the capture)
@@ -15,7 +15,7 @@ use regex::Regex;
 use shinkai_baml::baml_builder::{BamlConfig, ClientConfig, GeneratorConfig};
 use shinkai_dsl::dsl_schemas::Workflow;
 use shinkai_message_primitives::{
-    schemas::{inbox_name::InboxName, job::JobLike},
+    schemas::{inbox_name::InboxName, job::JobLike, llm_providers::common_agent_llm_provider::ProviderOrAgent},
     shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption},
 };
 use shinkai_sqlite::logger::{WorkflowLogEntry, WorkflowLogEntryStatus};
@@ -357,6 +357,7 @@ impl AsyncFunction for InferenceFunction {
             },
             None, // this is the config
             self.context.llm_stopper().clone(),
+            self.context.db().clone(),
         )
         .await
         .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?;
@@ -460,6 +461,7 @@ impl AsyncFunction for OpinionatedInferenceFunction {
             },
             None, // this is the config
             self.context.llm_stopper().clone(),
+            self.context.db().clone(),
         )
         .await
         .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?;
@@ -507,11 +509,26 @@ impl AsyncFunction for BamlInference {
 
         // TODO: do we need the job for something?
         // let full_job = self.context.full_job();
-        let llm_provider = self.context.agent();
+        let llm_or_agent_provider = self.context.agent();
 
         let generator_config = GeneratorConfig::default();
 
+        // TODO: add support for other providers
+        let llm_provider = match llm_or_agent_provider {
+            ProviderOrAgent::LLMProvider(provider) => provider,
+            ProviderOrAgent::Agent(agent) => {
+                let llm_provider_id = agent.llm_provider_id.clone();
+                let profile = agent.full_identity_name.clone();
+                let provider = self.context
+                    .db()
+                    .get_llm_provider(&llm_provider_id, &profile)
+                    .map_err(|e| WorkflowError::ExecutionError(e.to_string()))?;
+                &provider
+                    .as_ref()
+                    .ok_or_else(|| WorkflowError::ExecutionError("LLM provider not found".to_string()))?
+                    .clone()
+            }
+        };
+
         let base_url = llm_provider.external_url.clone().unwrap_or_default();
         let base_url = if base_url == "http://localhost:11434" || base_url == "http://localhost:11435" {
             format!("{}/v1", base_url)
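The match above is the only place this file destructures the new type. A minimal sketch of ProviderOrAgent as these call sites imply it — the canonical definition lives in shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider, and the Agent payload type plus the accessor below are assumptions:

// Sketch, not the canonical definition: just enough shape to satisfy the
// matches in this diff (LLMProvider wraps a SerializedLLMProvider; Agent
// carries at least llm_provider_id and full_identity_name).
pub enum ProviderOrAgent {
    LLMProvider(SerializedLLMProvider),
    Agent(Agent),
}

impl ProviderOrAgent {
    // Hypothetical convenience accessor mirroring the BamlInference match:
    // an agent defers to the LLM provider it is configured with.
    pub fn llm_provider_id(&self) -> &str {
        match self {
            ProviderOrAgent::LLMProvider(provider) => &provider.id,
            ProviderOrAgent::Agent(agent) => &agent.llm_provider_id,
        }
    }
}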
@@ -619,7 +636,15 @@ impl AsyncFunction for MultiInferenceFunction {
 
         let mut responses = Vec::new();
         let agent = self.context.agent();
-        let max_tokens = ModelCapabilitiesManager::get_max_input_tokens(&agent.model);
+        let max_tokens = ModelCapabilitiesManager::get_max_input_tokens_for_provider_or_agent(
+            agent.clone(),
+            self.context.db().clone(),
+        );
+
+        let max_tokens = match max_tokens {
+            Some(tokens) => tokens,
+            None => return Err(WorkflowError::ExecutionError("Max tokens not found".to_string())),
+        };
 
         for text in split_texts.iter() {
            let inference_args = vec![
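get_max_input_tokens_for_provider_or_agent replaces the model-only lookup here and in split_text_for_llm below, returning None when the agent's backing provider cannot be resolved. A hedged sketch of what it plausibly does, inferred from the call sites and the Result<Option<_>, _> lookup seen in BamlInference — the signature and body are guesses, not copied from the PR:

// Assumed implementation sketch: resolve an Agent to its backing provider via
// the database, then fall back to the existing per-model token lookup.
pub fn get_max_input_tokens_for_provider_or_agent(
    provider_or_agent: ProviderOrAgent,
    db: &ShinkaiDB, // exact handle type is a guess
) -> Option<usize> {
    match provider_or_agent {
        ProviderOrAgent::LLMProvider(provider) => {
            Some(ModelCapabilitiesManager::get_max_input_tokens(&provider.model))
        }
        ProviderOrAgent::Agent(agent) => {
            // Both a lookup error and Ok(None) collapse to None here.
            let provider = db
                .get_llm_provider(&agent.llm_provider_id, &agent.full_identity_name)
                .ok()??;
            Some(ModelCapabilitiesManager::get_max_input_tokens(&provider.model))
        }
    }
}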
(file path not shown in the capture)
@@ -399,6 +399,7 @@ pub fn search_embeddings_in_job_scope(
 #[cfg(test)]
 mod tests {
     use shinkai_db::db::ShinkaiDB;
+    use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
     use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{
         LLMProviderInterface, OpenAI, SerializedLLMProvider,
     };
@@ -696,13 +697,9 @@ mod tests {
         let agent = SerializedLLMProvider {
             id: "test_agent_id".to_string(),
             full_identity_name: agent_name,
-            perform_locally: false,
             external_url: Some("https://api.openai.com".to_string()),
             api_key: Some("mockapikey".to_string()),
             model: LLMProviderInterface::OpenAI(open_ai),
-            toolkit_permissions: vec![],
-            storage_bucket_permissions: vec![],
-            allowed_message_senders: vec![],
         };
         let image_files = HashMap::new();
 
@@ -717,7 +714,7 @@
             },
             None,
             image_files,
-            agent,
+            ProviderOrAgent::LLMProvider(agent),
             HashMap::new(),
             generator,
             ShinkaiName::default_testnet_localhost(),
@@ -832,13 +829,9 @@
         let agent = SerializedLLMProvider {
             id: "test_agent_id_with_query".to_string(),
             full_identity_name: agent_name,
-            perform_locally: false,
             external_url: Some("https://api.openai.com".to_string()),
             api_key: Some("mockapikey".to_string()),
             model: LLMProviderInterface::OpenAI(open_ai),
-            toolkit_permissions: vec![],
-            storage_bucket_permissions: vec![],
-            allowed_message_senders: vec![],
         };
         let image_files = HashMap::new();
 
@@ -853,7 +846,7 @@
             },
             None,
             image_files,
-            agent,
+            ProviderOrAgent::LLMProvider(agent),
             HashMap::new(),
             generator,
             ShinkaiName::default_testnet_localhost(),
@@ -983,13 +976,9 @@
         let agent = SerializedLLMProvider {
             id: "test_agent_id".to_string(),
             full_identity_name: agent_name,
-            perform_locally: false,
             external_url: Some("https://api.openai.com".to_string()),
             api_key: Some("mockapikey".to_string()),
             model: LLMProviderInterface::OpenAI(open_ai),
-            toolkit_permissions: vec![],
-            storage_bucket_permissions: vec![],
-            allowed_message_senders: vec![],
         };
         let image_files = HashMap::new();
 
@@ -1004,7 +993,7 @@
             },
             None,
             image_files,
-            agent,
+            ProviderOrAgent::LLMProvider(agent),
             HashMap::new(),
             generator,
             ShinkaiName::default_testnet_localhost(),
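Every updated fixture wraps a raw provider. For contrast, a hypothetical fixture using the other arm of the enum — the Agent field set here is guessed from the fields the BamlInference change reads, so treat it as illustrative only:

// Illustrative only: an Agent-backed input instead of a raw provider.
// Assumes Agent implements Default, which this diff does not confirm.
let agent = Agent {
    llm_provider_id: "test_llm_provider_id".to_string(),
    full_identity_name: agent_name.clone(),
    ..Default::default()
};
let provider_or_agent = ProviderOrAgent::Agent(agent);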
(file path not shown in the capture)
@@ -15,7 +15,12 @@ pub fn split_text_for_llm(
         .clone();
 
     let agent = context.agent();
-    let max_tokens = ModelCapabilitiesManager::get_max_input_tokens(&agent.model);
+    let max_tokens = ModelCapabilitiesManager::get_max_input_tokens_for_provider_or_agent(agent.clone(), context.db().clone());
+
+    let max_tokens = match max_tokens {
+        Some(tokens) => tokens,
+        None => return Err(WorkflowError::ExecutionError("Max tokens not found".to_string())),
+    };
 
     let mut result = Vec::new();
     let current_text = input1.clone();