Skip to content

Commit

Permalink
Nico/add ollama and shinkai backend (#160)
Browse files Browse the repository at this point in the history
* cp

* more fixes

* node working

* update libs
  • Loading branch information
nicarq authored Nov 30, 2023
1 parent e539110 commit e01788d
Show file tree
Hide file tree
Showing 26 changed files with 942 additions and 300 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-message-primitives/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-message-primitives/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shinkai_message_primitives"
version = "0.1.0"
version = "0.1.1"
edition = "2018"
authors = ["Nico Arqueros <[email protected]>"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,27 @@ pub enum AgentLLMInterface {
OpenAI(OpenAI),
#[serde(rename = "genericapi")]
GenericAPI(GenericAPI),
#[serde(rename = "ollama")]
Ollama(Ollama),
#[serde(rename = "shinkai-backend")]
ShinkaiBackend(ShinkaiBackend),
#[serde(rename = "local-llm")]
LocalLLM(LocalLLM),
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LocalLLM {}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Ollama {
pub model_type: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ShinkaiBackend {
pub model_type: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct OpenAI {
pub model_type: String,
Expand All @@ -50,8 +64,13 @@ impl FromStr for AgentLLMInterface {
} else if s.starts_with("genericapi:") {
let model_type = s.strip_prefix("genericapi:").unwrap_or("").to_string();
Ok(AgentLLMInterface::GenericAPI(GenericAPI { model_type }))
} else if s.starts_with("ollama:") {
let model_type = s.strip_prefix("ollama:").unwrap_or("").to_string();
Ok(AgentLLMInterface::Ollama(Ollama { model_type }))
} else if s.starts_with("shinkai-backend:") {
let model_type = s.strip_prefix("shinkai-backend:").unwrap_or("").to_string();
Ok(AgentLLMInterface::ShinkaiBackend(ShinkaiBackend { model_type }))
} else {
// TODO: nothing else for now
Err(())
}
}
Expand Down
4 changes: 2 additions & 2 deletions shinkai-libs/shinkai-message-pyo3/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-message-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shinkai_message_pyo3"
version = "0.1.4"
version = "0.1.5"
edition = "2018"
authors = ["Nico Arqueros <[email protected]>"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use pyo3::types::PyDict;
use shinkai_message_primitives::schemas::agents::serialized_agent::AgentLLMInterface;
use shinkai_message_primitives::schemas::agents::serialized_agent::GenericAPI;
use shinkai_message_primitives::schemas::agents::serialized_agent::LocalLLM;
use shinkai_message_primitives::schemas::agents::serialized_agent::Ollama;
use shinkai_message_primitives::schemas::agents::serialized_agent::OpenAI;
use shinkai_message_primitives::schemas::agents::serialized_agent::ShinkaiBackend;

#[pyclass]
#[derive(Debug, Clone)]
Expand All @@ -25,8 +27,17 @@ impl PyAgentLLMInterface {
Ok(Self {
inner: AgentLLMInterface::GenericAPI(GenericAPI { model_type }),
})
}
else {
} else if s.starts_with("ollama:") {
let model_type = s.strip_prefix("ollama:").unwrap_or("").to_string();
Ok(Self {
inner: AgentLLMInterface::Ollama(Ollama { model_type }),
})
} else if s.starts_with("shinkai-backend:") {
let model_type = s.strip_prefix("shinkai-backend:").unwrap_or("").to_string();
Ok(Self {
inner: AgentLLMInterface::ShinkaiBackend(ShinkaiBackend { model_type }),
})
} else {
Ok(Self {
inner: AgentLLMInterface::LocalLLM(LocalLLM {}),
})
Expand Down Expand Up @@ -64,6 +75,8 @@ impl PyAgentLLMInterface {
match &self.inner {
AgentLLMInterface::OpenAI(open_ai) => Ok(format!("openai:{}", open_ai.model_type)),
AgentLLMInterface::GenericAPI(generic_ai) => Ok(format!("genericapi:{}", generic_ai.model_type)),
AgentLLMInterface::Ollama(ollama) => Ok(format!("ollama:{}", ollama.model_type)),
AgentLLMInterface::ShinkaiBackend(shinkai_backend) => Ok(format!("shinkai-backend:{}", shinkai_backend.model_type)),
AgentLLMInterface::LocalLLM(_) => Ok("LocalLLM".to_string()),
}
}
Expand Down
4 changes: 2 additions & 2 deletions shinkai-libs/shinkai-message-wasm/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-message-wasm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shinkai_message_wasm"
version = "0.1.6"
version = "0.1.7"
edition = "2018"
authors = ["Nico Arqueros <[email protected]>"]

Expand Down
2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-message-wasm/pkg/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"collaborators": [
"Nico Arqueros <[email protected]>"
],
"version": "0.1.6",
"version": "0.1.7",
"files": [
"shinkai_message_wasm_bg.wasm",
"shinkai_message_wasm.js",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3053,13 +3053,13 @@ export function __wbindgen_string_new(arg0, arg1) {
return addHeapObject(ret);
};

export function __wbindgen_error_new(arg0, arg1) {
const ret = new Error(getStringFromWasm0(arg0, arg1));
export function __wbindgen_object_clone_ref(arg0) {
const ret = getObject(arg0);
return addHeapObject(ret);
};

export function __wbindgen_object_clone_ref(arg0) {
const ret = getObject(arg0);
export function __wbindgen_error_new(arg0, arg1) {
const ret = new Error(getStringFromWasm0(arg0, arg1));
return addHeapObject(ret);
};

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,10 @@ export function shinkaimessagebuilderwrapper_job_message(a: number, b: number, c
export function shinkaimessagebuilderwrapper_terminate_message(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number): void;
export function shinkaimessagebuilderwrapper_error_message(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number): void;
export function shinkaimessagebuilderwrapper_get_last_unread_messages_from_inbox(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number, p: number, q: number, r: number, s: number, t: number): void;
export function __wbg_inboxnamewrapper_free(a: number): void;
export function inboxnamewrapper_new(a: number, b: number): void;
export function inboxnamewrapper_to_string(a: number, b: number): void;
export function inboxnamewrapper_get_value(a: number): number;
export function inboxnamewrapper_get_is_e2e(a: number): number;
export function inboxnamewrapper_get_identities(a: number, b: number): void;
export function inboxnamewrapper_get_unique_id(a: number): number;
export function inboxnamewrapper_to_jsvalue(a: number, b: number): void;
export function inboxnamewrapper_to_json_str(a: number, b: number): void;
export function inboxnamewrapper_get_regular_inbox_name_from_params(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number): void;
export function inboxnamewrapper_get_job_inbox_name_from_params(a: number, b: number, c: number): void;
export function inboxnamewrapper_get_inner(a: number): number;
export function shinkaitime_generateTimeNow(a: number): void;
export function shinkaitime_generateTimeInFutureWithSecs(a: number, b: number): void;
export function shinkaitime_generateSpecificTime(a: number, b: number, c: number, d: number, e: number, f: number, g: number): void;
export function __wbg_shinkaitime_free(a: number): void;
export function __wbg_serializedagentwrapper_free(a: number): void;
export function serializedagentwrapper_fromStrings(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number, p: number, q: number, r: number, s: number): void;
export function serializedagentwrapper_fromJsValue(a: number, b: number): void;
Expand All @@ -56,46 +48,6 @@ export function serializedagentwrapper_from_json_str(a: number, b: number, c: nu
export function serializedagentwrapper_inner(a: number, b: number): void;
export function serializedagentwrapper_new(a: number, b: number): void;
export function serializedagentwrapper_to_jsvalue(a: number, b: number): void;
export function shinkaitime_generateTimeNow(a: number): void;
export function shinkaitime_generateTimeInFutureWithSecs(a: number, b: number): void;
export function shinkaitime_generateSpecificTime(a: number, b: number, c: number, d: number, e: number, f: number, g: number): void;
export function __wbg_shinkaitime_free(a: number): void;
export function __wbg_jobscopewrapper_free(a: number): void;
export function jobscopewrapper_new(a: number, b: number, c: number): void;
export function jobscopewrapper_to_jsvalue(a: number, b: number): void;
export function jobscopewrapper_to_json_str(a: number, b: number): void;
export function __wbg_jobcreationwrapper_free(a: number): void;
export function jobcreationwrapper_new(a: number, b: number): void;
export function jobcreationwrapper_to_jsvalue(a: number, b: number): void;
export function jobcreationwrapper_to_json_str(a: number, b: number): void;
export function jobcreationwrapper_get_scope(a: number, b: number): void;
export function jobcreationwrapper_from_json_str(a: number, b: number, c: number): void;
export function jobcreationwrapper_from_jsvalue(a: number, b: number): void;
export function jobcreationwrapper_empty(a: number): void;
export function __wbg_jobmessagewrapper_free(a: number): void;
export function jobmessagewrapper_new(a: number, b: number, c: number, d: number): void;
export function jobmessagewrapper_to_jsvalue(a: number, b: number): void;
export function jobmessagewrapper_to_json_str(a: number, b: number): void;
export function jobmessagewrapper_from_json_str(a: number, b: number, c: number): void;
export function jobmessagewrapper_from_jsvalue(a: number, b: number): void;
export function jobmessagewrapper_fromStrings(a: number, b: number, c: number, d: number, e: number, f: number): number;
export function __wbg_shinkainamewrapper_free(a: number): void;
export function shinkainamewrapper_new(a: number, b: number): void;
export function shinkainamewrapper_get_full_name(a: number): number;
export function shinkainamewrapper_get_node_name(a: number): number;
export function shinkainamewrapper_get_profile_name(a: number): number;
export function shinkainamewrapper_get_subidentity_type(a: number): number;
export function shinkainamewrapper_get_subidentity_name(a: number): number;
export function shinkainamewrapper_to_jsvalue(a: number, b: number): void;
export function shinkainamewrapper_to_json_str(a: number, b: number): void;
export function shinkainamewrapper_extract_profile(a: number, b: number): void;
export function shinkainamewrapper_extract_node(a: number): number;
export function __wbg_wasmencryptionmethod_free(a: number): void;
export function wasmencryptionmethod_new(a: number, b: number): number;
export function wasmencryptionmethod_as_str(a: number, b: number): void;
export function wasmencryptionmethod_DiffieHellmanChaChaPoly1305(a: number): void;
export function wasmencryptionmethod_None(a: number): void;
export function convert_encryption_sk_string_to_encryption_pk_string(a: number, b: number, c: number): void;
export function __wbg_shinkaimessagewrapper_free(a: number): void;
export function shinkaimessagewrapper_message_body(a: number, b: number): void;
export function shinkaimessagewrapper_set_message_body(a: number, b: number, c: number): void;
Expand All @@ -115,7 +67,55 @@ export function shinkaimessagewrapper_calculate_blake3_hash_with_empty_outer_sig
export function shinkaimessagewrapper_calculate_blake3_hash_with_empty_inner_signature(a: number, b: number): void;
export function shinkaimessagewrapper_generate_time_now(a: number): void;
export function shinkaimessagewrapper_new(a: number, b: number): void;
export function __wbg_shinkainamewrapper_free(a: number): void;
export function shinkainamewrapper_new(a: number, b: number): void;
export function shinkainamewrapper_get_full_name(a: number): number;
export function shinkainamewrapper_get_node_name(a: number): number;
export function shinkainamewrapper_get_profile_name(a: number): number;
export function shinkainamewrapper_get_subidentity_type(a: number): number;
export function shinkainamewrapper_get_subidentity_name(a: number): number;
export function shinkainamewrapper_to_jsvalue(a: number, b: number): void;
export function shinkainamewrapper_to_json_str(a: number, b: number): void;
export function shinkainamewrapper_extract_profile(a: number, b: number): void;
export function shinkainamewrapper_extract_node(a: number): number;
export function __wbg_wasmencryptionmethod_free(a: number): void;
export function wasmencryptionmethod_new(a: number, b: number): number;
export function wasmencryptionmethod_as_str(a: number, b: number): void;
export function wasmencryptionmethod_DiffieHellmanChaChaPoly1305(a: number): void;
export function wasmencryptionmethod_None(a: number): void;
export function convert_encryption_sk_string_to_encryption_pk_string(a: number, b: number, c: number): void;
export function calculate_blake3_hash(a: number, b: number, c: number): void;
export function __wbg_jobscopewrapper_free(a: number): void;
export function jobscopewrapper_new(a: number, b: number, c: number): void;
export function jobscopewrapper_to_jsvalue(a: number, b: number): void;
export function jobscopewrapper_to_json_str(a: number, b: number): void;
export function __wbg_jobcreationwrapper_free(a: number): void;
export function jobcreationwrapper_new(a: number, b: number): void;
export function jobcreationwrapper_to_jsvalue(a: number, b: number): void;
export function jobcreationwrapper_to_json_str(a: number, b: number): void;
export function jobcreationwrapper_get_scope(a: number, b: number): void;
export function jobcreationwrapper_from_json_str(a: number, b: number, c: number): void;
export function jobcreationwrapper_from_jsvalue(a: number, b: number): void;
export function jobcreationwrapper_empty(a: number): void;
export function __wbg_jobmessagewrapper_free(a: number): void;
export function jobmessagewrapper_new(a: number, b: number, c: number, d: number): void;
export function jobmessagewrapper_to_jsvalue(a: number, b: number): void;
export function jobmessagewrapper_to_json_str(a: number, b: number): void;
export function jobmessagewrapper_from_json_str(a: number, b: number, c: number): void;
export function jobmessagewrapper_from_jsvalue(a: number, b: number): void;
export function jobmessagewrapper_fromStrings(a: number, b: number, c: number, d: number, e: number, f: number): number;
export function __wbg_inboxnamewrapper_free(a: number): void;
export function inboxnamewrapper_new(a: number, b: number): void;
export function inboxnamewrapper_to_string(a: number, b: number): void;
export function inboxnamewrapper_get_value(a: number): number;
export function inboxnamewrapper_get_is_e2e(a: number): number;
export function inboxnamewrapper_get_identities(a: number, b: number): void;
export function inboxnamewrapper_get_unique_id(a: number): number;
export function inboxnamewrapper_to_jsvalue(a: number, b: number): void;
export function inboxnamewrapper_to_json_str(a: number, b: number): void;
export function inboxnamewrapper_get_regular_inbox_name_from_params(a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number): void;
export function inboxnamewrapper_get_job_inbox_name_from_params(a: number, b: number, c: number): void;
export function inboxnamewrapper_get_inner(a: number): number;
export function __wbindgen_malloc(a: number, b: number): number;
export function __wbindgen_realloc(a: number, b: number, c: number, d: number): number;
export function __wbindgen_add_to_stack_pointer(a: number): number;
Expand Down
10 changes: 10 additions & 0 deletions src/agent/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ impl Agent {
.call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
.await
}
AgentLLMInterface::Ollama(ollama) => {
ollama
.call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
.await
}
AgentLLMInterface::ShinkaiBackend(shinkai_backend) => {
shinkai_backend
.call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
.await
}
AgentLLMInterface::LocalLLM(local_llm) => {
self.inference_locally(prompt.generate_single_output_string()?).await
}
Expand Down
14 changes: 13 additions & 1 deletion src/agent/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::db::db_errors::ShinkaiDBError;
use crate::{db::db_errors::ShinkaiDBError, managers::agents_capabilities_manager::AgentsCapabilitiesManagerError};
use anyhow::Error as AnyhowError;
use shinkai_message_primitives::{
schemas::{inbox_name::InboxNameError, shinkai_name::ShinkaiNameError},
Expand Down Expand Up @@ -45,6 +45,9 @@ pub enum AgentError {
InvalidCronExecutionChainStage(String),
AnyhowError(AnyhowError),
AgentMissingCapabilities(String),
UnexpectedPromptResult(String),
AgentsCapabilitiesManagerError(AgentsCapabilitiesManagerError),
UnexpectedPromptResultVariant(String)
}

impl fmt::Display for AgentError {
Expand Down Expand Up @@ -98,6 +101,9 @@ impl fmt::Display for AgentError {
AgentError::InvalidCronExecutionChainStage(s) => write!(f, "Invalid cron execution chain stage: {}", s),
AgentError::AnyhowError(err) => write!(f, "{}", err),
AgentError::AgentMissingCapabilities(s) => write!(f, "Agent is missing capabilities: {}", s),
AgentError::UnexpectedPromptResult(s) => write!(f, "Unexpected prompt result: {}", s),
AgentError::AgentsCapabilitiesManagerError(err) => write!(f, "AgentsCapabilitiesManager error: {}", err),
AgentError::UnexpectedPromptResultVariant(s) => write!(f, "Unexpected prompt result variant: {}", s),
}
}
}
Expand Down Expand Up @@ -172,3 +178,9 @@ impl From<InboxNameError> for AgentError {
AgentError::InboxNameError(error)
}
}

impl From<AgentsCapabilitiesManagerError> for AgentError {
fn from(error: AgentsCapabilitiesManagerError) -> Self {
AgentError::AgentsCapabilitiesManagerError(error)
}
}
2 changes: 1 addition & 1 deletion src/agent/execution/job_prompts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::super::{error::AgentError, providers::openai::OpenAIApiMessage};
use super::super::{error::AgentError};
use crate::{agent::job::JobStepResult, tools::router::ShinkaiTool};
use futures::stream::ForEach;
use lazy_static::lazy_static;
Expand Down
Loading

0 comments on commit e01788d

Please sign in to comment.