Skip to content

Commit

Permalink
fix: doc chat & api key v2 (#532)
Browse files Browse the repository at this point in the history
* fix

* update api_v2_key

* make ollama use embedding batches

* api v2 always active

* Revert "make ollama use embedding batches"

This reverts commit c51e83a.

* bump: shinkai node v0.7.31

* fix test

---------

Co-authored-by: paulclindo <[email protected]>
  • Loading branch information
nicarq and paulclindo authored Aug 28, 2024
1 parent be9f0f1 commit 53f1117
Show file tree
Hide file tree
Showing 25 changed files with 194 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion shinkai-bin/shinkai-node/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "shinkai_node"
version = "0.7.30"
version = "0.7.31"
edition = "2021"
authors.workspace = true
# this causes `cargo run` in the workspace root to run this package
Expand Down
2 changes: 1 addition & 1 deletion shinkai-bin/shinkai-node/src/db/db_inbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ impl ShinkaiDB {
}
shinkai_log(
ShinkaiLogOption::Api,
ShinkaiLogLevel::Info,
ShinkaiLogLevel::Debug,
&format!("Inboxes: {}", inboxes.join(", ")),
);
Ok(inboxes)
Expand Down
16 changes: 16 additions & 0 deletions shinkai-bin/shinkai-node/src/db/db_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ impl ShinkaiDB {
.ok_or(ShinkaiDBError::ShinkaiNameLacksProfile)
}

/// Required for intra-communications between node UI and node.
/// Reads the stored API v2 key, returning `Ok(None)` when no key has been set yet.
pub fn read_api_v2_key(&self) -> Result<Option<String>, ShinkaiDBError> {
    let cf = self.get_cf_handle(Topic::NodeAndUsers)?;
    // Fetch the raw bytes; a storage-level failure is surfaced as FailedFetchingValue.
    let stored = self
        .db
        .get_cf(cf, b"api_v2_key")
        .map_err(|_| ShinkaiDBError::FailedFetchingValue)?;
    // Decode the bytes as UTF-8 when present; non-UTF-8 data is reported as InvalidData.
    stored
        .map(|bytes| String::from_utf8(bytes).map_err(|_| ShinkaiDBError::InvalidData))
        .transpose()
}

/// Sets the api_v2_key value.
///
/// Stores `key` under the `api_v2_key` entry of the `NodeAndUsers` column family,
/// overwriting any previously stored key.
///
/// # Panics
/// Panics if the `NodeAndUsers` column family handle cannot be obtained — this
/// would indicate the database was opened without its expected column families,
/// which is a startup invariant violation rather than a recoverable condition.
pub fn set_api_v2_key(&self, key: &str) -> Result<(), Error> {
    let cf = self
        .get_cf_handle(Topic::NodeAndUsers)
        .expect("NodeAndUsers column family must exist in an initialized ShinkaiDB");
    self.db.put_cf(cf, b"api_v2_key", key.as_bytes())
}

/// Returns the first half of the blake3 hash of the folder name value
pub fn user_profile_to_half_hash(profile: ShinkaiName) -> String {
let full_hash = blake3::hash(profile.full_name.as_bytes()).to_hex().to_string();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ impl JobPromptGenerator {
}
}

// Add the user question and the preference prompt for the answer
if !user_message.is_empty() {
let user_prompt = custom_user_prompt.unwrap_or_default();
let content = if user_prompt.is_empty() {
user_message.clone()
} else {
format!("{}\n {}", user_message, user_prompt)
};
prompt.add_content(content, SubPromptType::User, 100);
}

// Parses the retrieved nodes as individual sub-prompts, to support priority pruning
// and also grouping i.e. instead of having 100 tiny messages, we have a message with the chunks grouped
{
Expand All @@ -82,6 +71,17 @@ impl JobPromptGenerator {
}
}

// Add the user question and the preference prompt for the answer
if !user_message.is_empty() {
let user_prompt = custom_user_prompt.unwrap_or_default();
let content = if user_prompt.is_empty() {
user_message.clone()
} else {
format!("{}\n {}", user_message, user_prompt)
};
prompt.add_content(content, SubPromptType::UserLastMessage, 100);
}

// If function_call exists, it means that the LLM requested a function call and we need to send the response back
if let Some(function_call) = function_call {
// We add the assistant request to the prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ impl Prompt {
// Accumulator for ExtraContext content
let mut extra_context_content = String::new();
let mut processing_extra_context = false;
let mut last_user_message: Option<String> = None;

for sub_prompt in &self.sub_prompts {
match sub_prompt {
Expand Down Expand Up @@ -340,25 +341,10 @@ impl Prompt {
current_length += sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message);
tiktoken_messages.push(new_message);
}
SubPrompt::Content(SubPromptType::UserLastMessage, content, _) => {
last_user_message = Some(content.clone());
}
_ => {
// If we were processing ExtraContext, add it as a single System message
if processing_extra_context {
let extra_context_message = LlmMessage {
role: Some(SubPromptType::User.to_string()),
content: Some(extra_context_content.trim().to_string()),
name: None,
function_call: None,
functions: None,
};
current_length +=
ModelCapabilitiesManager::num_tokens_from_llama3(&[extra_context_message.clone()]);
tiktoken_messages.push(extra_context_message);

// Reset the accumulator
extra_context_content.clear();
processing_extra_context = false;
}

// Process the current sub-prompt
let new_message = sub_prompt.into_chat_completion_request_message();
current_length += sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message);
Expand All @@ -367,17 +353,25 @@ impl Prompt {
}
}

// If there are any remaining ExtraContext sub-prompts, add them as a single message
if processing_extra_context && !extra_context_content.is_empty() {
let extra_context_message = LlmMessage {
// Combine ExtraContext and UserLastMessage into one message
if !extra_context_content.is_empty() || last_user_message.is_some() {
let combined_content = format!(
"{}\n{}",
extra_context_content.trim(),
last_user_message.unwrap_or_default()
)
.trim()
.to_string();

let combined_message = LlmMessage {
role: Some(SubPromptType::User.to_string()),
content: Some(extra_context_content.trim().to_string()),
content: Some(combined_content),
name: None,
function_call: None,
functions: None,
};
current_length += ModelCapabilitiesManager::num_tokens_from_llama3(&[extra_context_message.clone()]);
tiktoken_messages.push(extra_context_message);
current_length += ModelCapabilitiesManager::num_tokens_from_llama3(&[combined_message.clone()]);
tiktoken_messages.push(combined_message);
}

(tiktoken_messages, current_length)
Expand Down Expand Up @@ -525,11 +519,6 @@ mod tests {

let (messages, _token_length) = prompt.generate_chat_completion_messages();

// match serde_json::to_string_pretty(&messages) {
// Ok(pretty_json) => eprintln!("messages JSON: {}", pretty_json),
// Err(e) => eprintln!("Failed to serialize tools_json: {:?}", e),
// };

// Expected messages
let expected_messages = vec![
LlmMessage {
Expand All @@ -553,13 +542,6 @@ mod tests {
function_call: None,
functions: None,
},
LlmMessage {
role: Some("user".to_string()),
content: Some("Here is a list of relevant new content provided for you to potentially use while answering:\n- FAQ Shinkai Overview What’s Shinkai? (Summary) (Source: Shinkai - Ask Me Anything.docx, Section: ) 2024-05-05T00:33:00\n- Shinkai is a comprehensive super app designed to enhance how users interact with AI. It allows users to run AI locally, facilitating direct conversations with documents and managing files converted into AI embeddings for advanced semantic searches across user data. This local execution ensures privacy and efficiency, putting control directly in the user's hands. (Source: Shinkai - Ask Me Anything.docx, Section: 2) 2024-05-05T00:33:00".to_string()),
name: None,
function_call: None,
functions: None,
},
LlmMessage {
role: Some("user".to_string()),
content: Some("tell me more about Shinkai. Answer the question using this markdown and the extra context provided: \n # Answer \n here goes the answer\n".to_string()),
Expand Down Expand Up @@ -602,6 +584,13 @@ mod tests {
},
}]),
},
LlmMessage {
role: Some("user".to_string()),
content: Some("Here is a list of relevant new content provided for you to potentially use while answering:\n- FAQ Shinkai Overview What’s Shinkai? (Summary) (Source: Shinkai - Ask Me Anything.docx, Section: ) 2024-05-05T00:33:00\n- Shinkai is a comprehensive super app designed to enhance how users interact with AI. It allows users to run AI locally, facilitating direct conversations with documents and managing files converted into AI embeddings for advanced semantic searches across user data. This local execution ensures privacy and efficiency, putting control directly in the user's hands. (Source: Shinkai - Ask Me Anything.docx, Section: 2) 2024-05-05T00:33:00".to_string()),
name: None,
function_call: None,
functions: None,
},
];

// Check if the generated messages match the expected messages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::prompts::Prompt;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SubPromptType {
User,
UserLastMessage,
System,
Assistant,
ExtraContext,
Expand All @@ -23,6 +24,7 @@ impl fmt::Display for SubPromptType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
SubPromptType::User => "user",
SubPromptType::UserLastMessage => "user",
SubPromptType::System => "system",
SubPromptType::Assistant => "assistant",
SubPromptType::ExtraContext => "user",
Expand Down
39 changes: 37 additions & 2 deletions shinkai-bin/shinkai-node/src/network/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ use chrono::Utc;
use core::panic;
use ed25519_dalek::{Signer, SigningKey, VerifyingKey};
use futures::{future::FutureExt, pin_mut, prelude::*, select};
use rand::Rng;
use rand::rngs::OsRng;
use rand::{Rng, RngCore};
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::SerializedLLMProvider;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::shinkai_network::NetworkMessageType;
Expand Down Expand Up @@ -139,6 +140,8 @@ pub struct Node {
pub default_embedding_model: Arc<Mutex<EmbeddingModelType>>,
// Supported embedding models for profiles
pub supported_embedding_models: Arc<Mutex<Vec<EmbeddingModelType>>>,
// API V2 Key
pub api_v2_key: String,
}

impl Node {
Expand All @@ -163,6 +166,7 @@ impl Node {
ws_address: Option<SocketAddr>,
default_embedding_model: EmbeddingModelType,
supported_embedding_models: Vec<EmbeddingModelType>,
api_v2_key: Option<String>,
) -> Arc<Mutex<Node>> {
// if is_valid_node_identity_name_and_no_subidentities is false panic
match ShinkaiName::new(node_name.to_string().clone()) {
Expand Down Expand Up @@ -347,6 +351,25 @@ impl Node {
.await;
let sheet_manager = sheet_manager_result.unwrap();

// It reads the api_v2_key from env, if not from db and if not, then it generates a new one that gets saved in the db
let api_v2_key = if let Some(key) = api_v2_key {
db_arc
.set_api_v2_key(&key)
.expect("Failed to set api_v2_key in the database");
key
} else {
match db_arc.read_api_v2_key() {
Ok(Some(key)) => key,
Ok(None) | Err(_) => {
let new_key = Node::generate_api_v2_key();
db_arc
.set_api_v2_key(&new_key)
.expect("Failed to set api_v2_key in the database");
new_key
}
}
};

Arc::new(Mutex::new(Node {
node_name: node_name.clone(),
identity_secret_key: clone_signature_secret_key(&identity_secret_key),
Expand Down Expand Up @@ -383,6 +406,7 @@ impl Node {
tool_router: Some(Arc::new(Mutex::new(tool_router))),
default_embedding_model,
supported_embedding_models,
api_v2_key,
}))
}

Expand Down Expand Up @@ -460,7 +484,12 @@ impl Node {
let reinstall_tools = std::env::var("REINSTALL_TOOLS").unwrap_or_else(|_| "false".to_string()) == "true";

tokio::spawn(async move {
let current_version = tool_router.lock().await.get_current_lancedb_version().await.unwrap_or(None);
let current_version = tool_router
.lock()
.await
.get_current_lancedb_version()
.await
.unwrap_or(None);
if reinstall_tools || current_version != Some(LATEST_ROUTER_DB_VERSION.to_string()) {
if let Err(e) = tool_router.lock().await.force_reinstall_all(generator).await {
eprintln!("ToolRouter force reinstall failed: {:?}", e);
Expand Down Expand Up @@ -1457,6 +1486,12 @@ impl Node {
Err(e) => eprintln!("Failed to read validation data: {}", e),
}
}

/// Generates a fresh 256-bit API v2 key from the OS randomness source,
/// returned as a base64-encoded string.
fn generate_api_v2_key() -> String {
    // 32 random bytes = 256 bits of entropy, drawn from the operating system CSPRNG.
    let random_bytes: [u8; 32] = {
        let mut buf = [0u8; 32];
        OsRng.fill_bytes(&mut buf);
        buf
    };
    base64::encode(&random_bytes)
}
}

impl Drop for Node {
Expand Down
25 changes: 9 additions & 16 deletions shinkai-bin/shinkai-node/src/network/node_api_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,17 @@ pub async fn run_api(
);
println!("API server running on http://{}", address);

if env::var("API_V2_KEY").is_ok() {
let v2_routes = warp::path("v2").and(
v2_routes(node_commands_sender.clone(), node_name.clone())
.recover(handle_rejection)
.with(log)
.with(cors.clone()),
);

// Combine all routes
let routes = v1_routes.or(v2_routes).with(log).with(cors);
let v2_routes = warp::path("v2").and(
v2_routes(node_commands_sender.clone(), node_name.clone())
.recover(handle_rejection)
.with(log)
.with(cors.clone()),
);

warp::serve(routes).run(address).await;
} else {
// Combine all routes
let routes = v1_routes.with(log).with(cors);
// Combine all routes
let routes = v1_routes.or(v2_routes).with(log).with(cors);

warp::serve(routes).run(address).await;
}
warp::serve(routes).run(address).await;

Ok(())
}
Expand Down
29 changes: 29 additions & 0 deletions shinkai-bin/shinkai-node/src/network/v1_api/api_v1_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,20 @@ impl Node {
identity_type: standard_identity_type,
permission_type,
};

let api_v2_key = match db.read_api_v2_key() {
Ok(Some(api_key)) => api_key,
Ok(None) | Err(_) => {
let api_error = APIError {
code: StatusCode::UNAUTHORIZED.as_u16(),
error: "Unauthorized".to_string(),
message: "Invalid bearer token".to_string(),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

let mut subidentity_manager = identity_manager.lock().await;
match subidentity_manager.add_profile_subidentity(subidentity).await {
Ok(_) => {
Expand All @@ -900,6 +914,7 @@ impl Node {
node_name: node_name.get_node_name_string().clone(),
encryption_public_key: encryption_public_key_to_string(encryption_public_key),
identity_public_key: signature_public_key_to_string(identity_public_key),
api_v2_key
};
let _ = res.send(Ok(success_response)).await.map_err(|_| ());
}
Expand Down Expand Up @@ -984,6 +999,19 @@ impl Node {
permission_type,
};

let api_v2_key = match db.read_api_v2_key() {
Ok(Some(api_key)) => api_key,
Ok(None) | Err(_) => {
let api_error = APIError {
code: StatusCode::UNAUTHORIZED.as_u16(),
error: "Unauthorized".to_string(),
message: "Invalid bearer token".to_string(),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

let mut identity_manager_mut = identity_manager.lock().await;
match identity_manager_mut.add_device_subidentity(device_identity).await {
Ok(_) => {
Expand All @@ -1009,6 +1037,7 @@ impl Node {
node_name: node_name.get_node_name_string().clone(),
encryption_public_key: encryption_public_key_to_string(encryption_public_key),
identity_public_key: signature_public_key_to_string(identity_public_key),
api_v2_key,
};
let _ = res.send(Ok(success_response)).await.map_err(|_| ());
}
Expand Down
Loading

0 comments on commit 53f1117

Please sign in to comment.